mirror of
https://github.com/postgres-haskell/postgres-wire.git
synced 2024-11-22 05:53:12 +03:00
Change ByteString to Encode in Bind message
This commit is contained in:
parent
a2536fd5f4
commit
940cf38f9c
@ -82,13 +82,13 @@ encodeMessage params = runEncode $
|
||||
bindMessage = Bind (PortalName "") stmtName Binary
|
||||
(encodedParams params) Binary
|
||||
encodedParams (a, b, c, d, e, f, g) =
|
||||
[ Just . runEncode $ PE.bool a
|
||||
, Just . runEncode $ PE.bytea b
|
||||
, Just . runEncode $ PE.float8 c
|
||||
, Just . runEncode $ PE.interval d
|
||||
, Just . runEncode $ PE.numeric e
|
||||
, Just . runEncode $ PE.timestamptz f
|
||||
, Just . runEncode $ PE.uuid g
|
||||
[ Just $ PE.bool a
|
||||
, Just $ PE.bytea b
|
||||
, Just $ PE.float8 c
|
||||
, Just $ PE.interval d
|
||||
, Just $ PE.numeric e
|
||||
, Just $ PE.timestamptz f
|
||||
, Just $ PE.uuid g
|
||||
]
|
||||
parseMessage = Parse stmtName stmt oids
|
||||
stmtName = StatementName "_pw_statement_0010"
|
||||
|
@ -32,7 +32,7 @@ import Database.PostgreSQL.Driver.StatementStorage
|
||||
-- Public
|
||||
data Query = Query
|
||||
{ qStatement :: B.ByteString
|
||||
, qValues :: [(Oid, Maybe B.ByteString)]
|
||||
, qValues :: [(Oid, Maybe Encode)]
|
||||
, qParamsFormat :: Format
|
||||
, qResultFormat :: Format
|
||||
, qCachePolicy :: CachePolicy
|
||||
|
@ -80,10 +80,9 @@ encodeClientMessage Terminate
|
||||
-- | Encodes single data values. Length `-1` indicates a NULL parameter value.
|
||||
-- No value bytes follow in the NULL case.
|
||||
{-# INLINE encodeValue #-}
|
||||
encodeValue :: Maybe B.ByteString -> Encode
|
||||
encodeValue :: Maybe Encode -> Encode
|
||||
encodeValue Nothing = putWord32BE (-1)
|
||||
encodeValue (Just v) = putWord32BE (fromIntegral $ B.length v)
|
||||
<> putByteString v
|
||||
encodeValue (Just v) = putWord32BE (fromIntegral $ getEncodeLen v) <> v
|
||||
|
||||
{-# INLINE encodeFormat #-}
|
||||
encodeFormat :: Format -> Encode
|
||||
|
@ -20,6 +20,9 @@ instance Monoid Encode where
|
||||
{-# INLINE mappend #-}
|
||||
(Encode len1 f1) `mappend` (Encode len2 f2) = Encode (len1 + len2) (f1 *> f2)
|
||||
|
||||
instance Show Encode where
|
||||
show (Encode len _) = "Encode instance of length " ++ show len
|
||||
|
||||
{-# INLINE getEncodeLen #-}
|
||||
getEncodeLen :: Encode -> Int
|
||||
getEncodeLen (Encode len _) = len
|
||||
|
@ -16,6 +16,7 @@ import Data.Hashable (Hashable)
|
||||
import Data.ByteString as B(ByteString)
|
||||
import qualified Data.ByteString.Lazy as BL(ByteString)
|
||||
import Data.Vector (Vector)
|
||||
import Database.PostgreSQL.Protocol.Store.Encode (Encode)
|
||||
|
||||
-- Common
|
||||
newtype Oid = Oid { unOid :: Word32 } deriving (Show, Eq)
|
||||
@ -128,9 +129,9 @@ data AuthResponse
|
||||
data ClientMessage
|
||||
= Bind !PortalName !StatementName
|
||||
!Format -- parameter format code, one format for all
|
||||
![Maybe ByteString] -- the values of parameters, Nothing
|
||||
![Maybe Encode] -- the values of parameters, Nothing
|
||||
-- is recognized as NULL
|
||||
!Format -- to apply code to all result columns
|
||||
!Format -- to apply code to all result columns
|
||||
-- Postgres use one command `close` for closing both statements and
|
||||
-- portals, but we distinguish them
|
||||
| CloseStatement !StatementName
|
||||
|
@ -35,8 +35,7 @@ makeCodecProperty
|
||||
-> Oid -> (a -> Encode) -> PD.FieldDecoder a
|
||||
-> a -> Property
|
||||
makeCodecProperty c oid encoder fd v = monadicIO $ do
|
||||
let bs = runEncode $ encoder v
|
||||
q = Query "SELECT $1" [(oid, Just bs)]
|
||||
let q = Query "SELECT $1" [(oid, Just $ encoder v)]
|
||||
Binary Binary AlwaysCache
|
||||
decoder = PD.dataRowHeader *> PD.getNonNullable fd
|
||||
r <- run $ do
|
||||
@ -58,8 +57,7 @@ makeCodecEncodeProperty
|
||||
-> (a -> String)
|
||||
-> a -> Property
|
||||
makeCodecEncodeProperty c oid queryString encoder fPrint v = monadicIO $ do
|
||||
let bs = runEncode $ encoder v
|
||||
q = Query queryString [(oid, Just bs)]
|
||||
let q = Query queryString [(oid, Just $ encoder v)]
|
||||
Binary Text AlwaysCache
|
||||
decoder = PD.dataRowHeader *> PD.getNonNullable PD.bytea
|
||||
r <- run $ do
|
||||
|
@ -23,6 +23,7 @@ import Database.PostgreSQL.Protocol.Store.Decode
|
||||
import Database.PostgreSQL.Protocol.Decoders
|
||||
|
||||
import Database.PostgreSQL.Protocol.Codecs.Decoders
|
||||
import Database.PostgreSQL.Protocol.Codecs.Encoders as PE
|
||||
|
||||
import Connection
|
||||
|
||||
@ -45,11 +46,13 @@ testDriver = testGroup "Driver"
|
||||
]
|
||||
|
||||
makeQuery1 :: B.ByteString -> Query
|
||||
makeQuery1 n = Query "SELECT $1" [(Oid 23, Just n)] Text Text AlwaysCache
|
||||
makeQuery1 n = Query "SELECT $1" [(Oid 23, Just $ PE.bytea n )]
|
||||
Text Text AlwaysCache
|
||||
|
||||
makeQuery2 :: B.ByteString -> B.ByteString -> Query
|
||||
makeQuery2 n1 n2 = Query "SELECT $1 + $2"
|
||||
[(Oid 23, Just n1), (Oid 23, Just n2)] Text Text AlwaysCache
|
||||
[(Oid 23, Just $ PE.bytea n1), (Oid 23, Just $ PE.bytea n2)]
|
||||
Text Text AlwaysCache
|
||||
|
||||
fromRight :: Either e a -> a
|
||||
fromRight (Right v) = v
|
||||
@ -140,8 +143,10 @@ checkInvalidResult conn n = readNextData conn >>=
|
||||
testInvalidBatch :: IO ()
|
||||
testInvalidBatch = do
|
||||
let rightQuery = makeQuery1 "5"
|
||||
q1 = Query "SEL $1" [(Oid 23, Just "5")] Text Text NeverCache
|
||||
q2 = Query "SELECT $1" [(Oid 23, Just "a")] Text Text NeverCache
|
||||
q1 = Query "SEL $1" [(Oid 23, Just $ PE.bytea "5")]
|
||||
Text Text NeverCache
|
||||
q2 = Query "SELECT $1" [(Oid 23, Just $ PE.bytea "a")]
|
||||
Text Text NeverCache
|
||||
q4 = Query "SELECT $1" [] Text Text NeverCache
|
||||
|
||||
assertInvalidBatch "Parse error" [q1]
|
||||
|
@ -12,6 +12,7 @@ import Database.PostgreSQL.Driver.StatementStorage
|
||||
import Database.PostgreSQL.Driver.Query
|
||||
import Database.PostgreSQL.Driver.Error
|
||||
import Database.PostgreSQL.Protocol.Types
|
||||
import Database.PostgreSQL.Protocol.Codecs.Encoders as PE
|
||||
|
||||
import Connection
|
||||
|
||||
@ -50,7 +51,7 @@ testExtendedQuery = withConnectionCommonAll $ \c -> do
|
||||
statement = StatementSQL "SELECT $1 + $2"
|
||||
sendMessage rawConn $ Parse sname statement [Oid 23, Oid 23]
|
||||
sendMessage rawConn $
|
||||
Bind pname sname Text [Just "1", Just "2"] Text
|
||||
Bind pname sname Text [Just $ PE.bytea "1", Just $ PE.bytea "2"] Text
|
||||
sendMessage rawConn $ Execute pname noLimitToReceive
|
||||
sendMessage rawConn $ DescribeStatement sname
|
||||
sendMessage rawConn $ DescribePortal pname
|
||||
|
Loading…
Reference in New Issue
Block a user