From eef845f0d38102de82e0b183ea78f1e6aac30576 Mon Sep 17 00:00:00 2001 From: Nikita Volkov Date: Sat, 27 Jan 2024 02:28:43 +0300 Subject: [PATCH] Adapt the params encoder --- CHANGELOG.md | 5 ++ library/Hasql/Encoders/All.hs | 2 +- library/Hasql/Encoders/Params.hs | 77 +++++++++++++++++++++++------ library/Hasql/IO.hs | 32 ++++++------ library/Hasql/Session/Core.hs | 8 ++- library/Hasql/Statement/Function.hs | 5 +- tasty/Main.hs | 26 ---------- 7 files changed, 92 insertions(+), 63 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a96c1de..5cdd40f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +# 1.7 + +- Added `Statement.function` for easier integration with stored procedures. +- Decidable instance on `Encoders.Params` removed. It was useless and limited the design. + # 1.6.3.1 - Moved to "postgresql-libpq-0.10" diff --git a/library/Hasql/Encoders/All.hs b/library/Hasql/Encoders/All.hs index 8691e81..c2c4d24 100644 --- a/library/Hasql/Encoders/All.hs +++ b/library/Hasql/Encoders/All.hs @@ -61,7 +61,7 @@ import qualified Text.Builder as C -- Female -> "female" -- @ newtype Params a = Params (Params.Params a) - deriving (Contravariant, Divisible, Decidable, Monoid, Semigroup) + deriving (Contravariant, Divisible, Monoid, Semigroup) -- | -- No parameters. Same as `mempty` and `conquered`. diff --git a/library/Hasql/Encoders/Params.hs b/library/Hasql/Encoders/Params.hs index 63715cf..51e2cf8 100644 --- a/library/Hasql/Encoders/Params.hs +++ b/library/Hasql/Encoders/Params.hs @@ -9,23 +9,70 @@ import qualified Text.Builder as E -- | -- Encoder of some representation of a parameters product. -newtype Params a - = Params (Op (DList (A.Oid, A.Format, Bool -> Maybe ByteString, Text)) a) - deriving (Contravariant, Divisible, Decidable, Semigroup, Monoid) +data Params a = Params + { size :: !Int, + columnsMetadata :: !(DList (A.Oid, A.Format)), + serializer :: Bool -> a -> DList (Maybe ByteString), + printer :: a -> DList Text + } + +instance Contravariant Params where + contramap fn (Params size columnsMetadata oldSerializer oldPrinter) = Params {..} + where + serializer idt = oldSerializer idt . fn + printer = oldPrinter . fn + +instance Divisible Params where + divide + divisor + (Params leftSize leftColumnsMetadata leftSerializer leftPrinter) + (Params rightSize rightColumnsMetadata rightSerializer rightPrinter) = + Params + { size = leftSize + rightSize, + columnsMetadata = leftColumnsMetadata <> rightColumnsMetadata, + serializer = \idt input -> case divisor input of + (leftInput, rightInput) -> leftSerializer idt leftInput <> rightSerializer idt rightInput, + printer = \input -> case divisor input of + (leftInput, rightInput) -> leftPrinter leftInput <> rightPrinter rightInput + } + conquer = + Params + { size = 0, + columnsMetadata = mempty, + serializer = mempty, + printer = mempty + } + +instance Semigroup (Params a) where + Params leftSize leftColumnsMetadata leftSerializer leftPrinter <> Params rightSize rightColumnsMetadata rightSerializer rightPrinter = + Params + { size = leftSize + rightSize, + columnsMetadata = leftColumnsMetadata <> rightColumnsMetadata, + serializer = \idt input -> leftSerializer idt input <> rightSerializer idt input, + printer = \input -> leftPrinter input <> rightPrinter input + } + +instance Monoid (Params a) where + mempty = conquer value :: C.Value a -> Params a -value = - contramap Just . nullableValue +value (C.Value valueOID _ serialize print) = + Params + { size = 1, + columnsMetadata = pure (pqOid, format), + serializer = \idt -> pure . Just . B.encodingBytes . serialize idt, + printer = pure . E.run . print + } + where + D.OID _ pqOid format = valueOID nullableValue :: C.Value a -> Params (Maybe a) -nullableValue (C.Value valueOID arrayOID encode render) = +nullableValue (C.Value valueOID _ serialize print) = Params - $ Op - $ \input -> - let D.OID _ pqOid format = - valueOID - encoder env = - fmap (B.encodingBytes . encode env) input - rendering = - maybe "null" (E.run . render) input - in pure (pqOid, format, encoder, rendering) + { size = 1, + columnsMetadata = pure (pqOid, format), + serializer = \idt -> pure . fmap (B.encodingBytes . serialize idt), + printer = pure . maybe "null" (E.run . print) + } + where + D.OID _ pqOid format = valueOID diff --git a/library/Hasql/IO.hs b/library/Hasql/IO.hs index 47c4b3a..db9dd94 100644 --- a/library/Hasql/IO.hs +++ b/library/Hasql/IO.hs @@ -114,16 +114,17 @@ sendPreparedParametricStatement :: ParamsEncoders.Params a -> a -> IO (Either CommandError ()) -sendPreparedParametricStatement connection registry integerDatetimes template (ParamsEncoders.Params (Op encoderOp)) input = - let (oidList, valueAndFormatList) = - let step (oid, format, encoder, _) ~(oidList, bytesAndFormatList) = - (,) - (oid : oidList) - (fmap (\bytes -> (bytes, format)) (encoder integerDatetimes) : bytesAndFormatList) - in foldr step ([], []) (encoderOp input) - in runExceptT $ do - key <- ExceptT $ getPreparedStatementKey connection registry template oidList - ExceptT $ checkedSend connection $ LibPQ.sendQueryPrepared connection key valueAndFormatList LibPQ.Binary +sendPreparedParametricStatement connection registry integerDatetimes template (ParamsEncoders.Params size columnsMetadata serializer _) input = + runExceptT $ do + key <- ExceptT $ getPreparedStatementKey connection registry template oidList + ExceptT $ checkedSend connection $ LibPQ.sendQueryPrepared connection key valueAndFormatList LibPQ.Binary + where + (oidList, formatList) = + columnsMetadata & toList & unzip + valueAndFormatList = + serializer integerDatetimes input + & toList + & zipWith (\format encoding -> (,format) <$> encoding) formatList {-# INLINE sendUnpreparedParametricStatement #-} sendUnpreparedParametricStatement :: @@ -133,11 +134,14 @@ sendUnpreparedParametricStatement :: ParamsEncoders.Params a -> a -> IO (Either CommandError ()) -sendUnpreparedParametricStatement connection integerDatetimes template (ParamsEncoders.Params (Op encoderOp)) input = +sendUnpreparedParametricStatement connection integerDatetimes template (ParamsEncoders.Params _ columnsMetadata serializer printer) input = let params = - let step (oid, format, encoder, _) acc = - ((,,) <$> pure oid <*> encoder integerDatetimes <*> pure format) : acc - in foldr step [] (encoderOp input) + zipWith + ( \(oid, format) encoding -> + (,,) <$> pure oid <*> encoding <*> pure format + ) + (toList columnsMetadata) + (toList (serializer integerDatetimes input)) in checkedSend connection $ LibPQ.sendQueryParams connection template params LibPQ.Binary {-# INLINE sendParametricStatement #-} diff --git a/library/Hasql/Session/Core.hs b/library/Hasql/Session/Core.hs index 09f1ffe..febafd8 100644 --- a/library/Hasql/Session/Core.hs +++ b/library/Hasql/Session/Core.hs @@ -46,7 +46,7 @@ sql sql = -- | -- Parameters and a specification of a parametric single-statement query to apply them to. statement :: params -> Statement.Statement params result -> Session result -statement input (Statement.Statement template (Encoders.Params paramsEncoder) decoder preparable) = +statement input (Statement.Statement template (Encoders.Params paramsEncoder@(Encoders.Params.Params _ _ _ printer)) decoder preparable) = Session $ ReaderT $ \(Connection.Connection pqConnectionRef integerDatetimes registry) -> @@ -59,7 +59,5 @@ statement input (Statement.Statement template (Encoders.Params paramsEncoder) de return $ r1 *> r2 where inputReps = - let Encoders.Params.Params (Op encoderOp) = paramsEncoder - step (_, _, _, rendering) acc = - rendering : acc - in foldr step [] (encoderOp input) + printer input + & toList diff --git a/library/Hasql/Statement/Function.hs b/library/Hasql/Statement/Function.hs index 7d8e774..23d0e80 100644 --- a/library/Hasql/Statement/Function.hs +++ b/library/Hasql/Statement/Function.hs @@ -2,11 +2,12 @@ module Hasql.Statement.Function where import qualified ByteString.StrictBuilder as Builder import qualified Hasql.Encoders.All as Encoders +import qualified Hasql.Encoders.Params as Encoders.Params import Hasql.Prelude import qualified Hasql.Statement.Function.SqlBuilder as SqlBuilder sql :: Text -> Encoders.Params a -> ByteString -sql name encoders = +sql name (Encoders.Params (Encoders.Params.Params size _ _ _)) = Builder.builderBytes $ SqlBuilder.sql name - $ error "TODO: Get size" encoders + $ size diff --git a/tasty/Main.hs b/tasty/Main.hs index 36fbb42..560e79e 100644 --- a/tasty/Main.hs +++ b/tasty/Main.hs @@ -372,32 +372,6 @@ tree = Encoders.param (Encoders.nonNullable (Encoders.unknown)) in DSL.statement "ok" statement in actualIO >>= assertEqual "" (Right True), - testCase "Textual Unknown" - $ let actualIO = - DSL.session $ do - let statement = - Statement.Statement sql mempty Decoders.noResult True - where - sql = - "create or replace function overloaded(a int, b int) returns int as $$ select a + b $$ language sql;" - in DSL.statement () statement - let statement = - Statement.Statement sql mempty Decoders.noResult True - where - sql = - "create or replace function overloaded(a text, b text, c text) returns text as $$ select a || b || c $$ language sql;" - in DSL.statement () statement - let statement = - Statement.Statement sql encoder decoder True - where - sql = - "select overloaded($1, $2) || overloaded($3, $4, $5)" - decoder = - (Decoders.singleRow ((Decoders.column . Decoders.nonNullable) (Decoders.text))) - encoder = - contramany (Encoders.param (Encoders.nonNullable (Encoders.unknown))) - in DSL.statement ["1", "2", "4", "5", "6"] statement - in actualIO >>= assertEqual "" (Right "3456"), testCase "Enum" $ let actualIO = DSL.session $ do