Fixed encoders with newtypes

This commit is contained in:
VyacheslavHashov 2017-01-18 19:57:05 +03:00
parent b9b0098ac0
commit da8bba2c74
3 changed files with 27 additions and 22 deletions

View File

@ -17,21 +17,23 @@ currentVersion = 3 * 256 * 256
encodeStartMessage :: StartMessage -> Builder
-- Options except user and database are not supported
encodeStartMessage (StartupMessage uname dbname) =
encodeStartMessage (StartupMessage (Username uname) (DatabaseName dbname)) =
int32BE (len + 4) <> payload
where
len = fromIntegral $ BL.length $ toLazyByteString payload
payload = int32BE currentVersion <>
pgString "user" <> pgString uname <>
pgString "database" <> pgString dbname <> word8 0
-- TODO
encodeStartMessage SSLRequest = undefined
encodeClientMessage :: ClientMessage -> Builder
encodeClientMessage (Bind portalName stmtName paramFormat values resultFormat)
encodeClientMessage (Bind (PortalName portalName) (StatementName stmtName)
paramFormat values resultFormat)
= prependHeader 'B' $
pgString portalName <>
pgString stmtName <>
-- the specified format code is applied to all parameters
-- `1` means that the specified format code is applied to all parameters
int16BE 1 <>
encodeFormat paramFormat <>
int16BE (fromIntegral $ V.length values) <>
@ -39,18 +41,19 @@ encodeClientMessage (Bind portalName stmtName paramFormat values resultFormat)
-- follow in the NULL case.
fold ((\v -> int32BE (fromIntegral $ B.length v) <> byteString v)
<$> values) <>
-- the specified format code is applied to all result columns (if any)
-- `1` means that the specified format code is applied to all
-- result columns (if any)
int16BE 1 <>
encodeFormat resultFormat
encodeClientMessage (CloseStatement stmtName)
encodeClientMessage (CloseStatement (StatementName stmtName))
= prependHeader 'C' $ char8 'S' <> pgString stmtName
encodeClientMessage (ClosePortal portalName)
encodeClientMessage (ClosePortal (PortalName portalName))
= prependHeader 'C' $ char8 'P' <> pgString portalName
encodeClientMessage (DescribeStatement stmtName)
encodeClientMessage (DescribeStatement (StatementName stmtName))
= prependHeader 'D' $ char8 'S' <> pgString stmtName
encodeClientMessage (DescribePortal portalName)
encodeClientMessage (DescribePortal (PortalName portalName))
= prependHeader 'D' $ char8 'P' <> pgString portalName
encodeClientMessage (Execute portalName)
encodeClientMessage (Execute (PortalName portalName))
= prependHeader 'E' $
pgString portalName <>
--Maximum number of rows to return, if portal contains a query that
@ -58,15 +61,15 @@ encodeClientMessage (Execute portalName)
int32BE 0
encodeClientMessage Flush
= prependHeader 'H' mempty
encodeClientMessage (Parse stmtName stmt oids)
encodeClientMessage (Parse (StatementName stmtName) (StatementSQL stmt) oids)
= prependHeader 'P' $
pgString stmtName <>
pgString stmt <>
int16BE (fromIntegral $ V.length oids) <>
fold (int32BE <$> oids)
encodeClientMessage (PasswordMessage passText)
fold (int32BE . unOid <$> oids)
encodeClientMessage (PasswordMessage (PasswordText passText))
= prependHeader 'p' $ pgString passText
encodeClientMessage (Query stmt)
encodeClientMessage (SimpleQuery (StatementSQL stmt))
= prependHeader 'Q' $ pgString stmt
encodeClientMessage Sync
= prependHeader 'S' mempty

View File

@ -22,7 +22,7 @@ storageStatement :: StatementStorage -> StatementSQL -> IO StatementName
storageStatement (StatementStorage table counter) stmt = do
n <- readIORef counter
writeIORef counter $ n + 1
let name = pack $ show n
H.insert table name stmt
let name = StatementName . pack $ show n
H.insert table stmt name
pure name

View File

@ -1,14 +1,15 @@
module Database.PostgreSQL.Protocol.Types where
import Data.Word (Word32, Word8)
import Data.Int (Int32)
import Data.Int (Int32, Int16)
import Data.Hashable (Hashable)
import qualified Data.ByteString as B
import qualified Data.Vector as V
-- Common
newtype Oid = Oid Int32 deriving (Show)
newtype Oid = Oid { unOid :: Int32 } deriving (Show)
newtype StatementName = StatementName B.ByteString deriving (Show)
newtype StatementSQL = StatementSQL B.ByteString deriving (Show)
newtype StatementSQL = StatementSQL B.ByteString deriving (Show, Eq, Hashable)
newtype PortalName = PortalName B.ByteString deriving (Show)
newtype ChannelName = ChannelName B.ByteString deriving (Show)
@ -21,7 +22,7 @@ newtype MD5Salt = MD5Salt Word32 deriving (Show)
newtype ServerProccessId = ServerProcessId Int32 deriving (Show)
newtype ServerSecretKey = ServerSecrecKey Int32 deriving (Show)
newtype RowsCount = RowsCount Word
newtype RowsCount = RowsCount Word deriving (Show)
-- | Information about completed command.
data CommandResult
@ -34,6 +35,7 @@ data CommandResult
| MoveCompleted RowsCount
| FetchCompleted RowsCount
| CopyCompleted RowsCount
deriving (Show)
-- | Parameters of the current connection.
-- We store only the parameters that cannot change after startup.
@ -105,7 +107,7 @@ data ServerMessage
= BackendKeyData ServerProccessId ServerSecretKey
| BindComplete
| CloseComplete
| CommandComplete CommandTag
| CommandComplete CommandResult
| DataRow (V.Vector B.ByteString) -- the values of a result
| EmptyQueryResponse
| ErrorResponse ErrorDesc
@ -175,7 +177,7 @@ data ErrorDesc = ErrorDesc
, errorDataType :: Maybe B.ByteString
, errorConstraint :: Maybe B.ByteString
, errorSourceFilename :: Maybe B.ByteString
, errorSourceLine :: Maybe B.Int
, errorSourceLine :: Maybe Int
, errorRoutine :: Maybe B.ByteString
} deriving (Show)
@ -195,7 +197,7 @@ data NoticeDesc = NoticeDesc
, noticeDataType :: Maybe B.ByteString
, noticeConstraint :: Maybe B.ByteString
, noticeSourceFilename :: Maybe B.ByteString
, noticeSourceLine :: Maybe B.Int
, noticeSourceLine :: Maybe Int
, noticeRoutine :: Maybe B.ByteString
} deriving (Show)