Fixed encoders with newtypes

This commit is contained in:
VyacheslavHashov 2017-01-18 19:57:05 +03:00
parent b9b0098ac0
commit da8bba2c74
3 changed files with 27 additions and 22 deletions

View File

@ -17,21 +17,23 @@ currentVersion = 3 * 256 * 256
encodeStartMessage :: StartMessage -> Builder encodeStartMessage :: StartMessage -> Builder
-- Options except user and database are not supported -- Options except user and database are not supported
encodeStartMessage (StartupMessage uname dbname) = encodeStartMessage (StartupMessage (Username uname) (DatabaseName dbname)) =
int32BE (len + 4) <> payload int32BE (len + 4) <> payload
where where
len = fromIntegral $ BL.length $ toLazyByteString payload len = fromIntegral $ BL.length $ toLazyByteString payload
payload = int32BE currentVersion <> payload = int32BE currentVersion <>
pgString "user" <> pgString uname <> pgString "user" <> pgString uname <>
pgString "database" <> pgString dbname <> word8 0 pgString "database" <> pgString dbname <> word8 0
-- TODO
encodeStartMessage SSLRequest = undefined encodeStartMessage SSLRequest = undefined
encodeClientMessage :: ClientMessage -> Builder encodeClientMessage :: ClientMessage -> Builder
encodeClientMessage (Bind portalName stmtName paramFormat values resultFormat) encodeClientMessage (Bind (PortalName portalName) (StatementName stmtName)
paramFormat values resultFormat)
= prependHeader 'B' $ = prependHeader 'B' $
pgString portalName <> pgString portalName <>
pgString stmtName <> pgString stmtName <>
-- the specified format code is applied to all parameters -- `1` means that the specified format code is applied to all parameters
int16BE 1 <> int16BE 1 <>
encodeFormat paramFormat <> encodeFormat paramFormat <>
int16BE (fromIntegral $ V.length values) <> int16BE (fromIntegral $ V.length values) <>
@ -39,18 +41,19 @@ encodeClientMessage (Bind portalName stmtName paramFormat values resultFormat)
-- follow in the NULL case. -- follow in the NULL case.
fold ((\v -> int32BE (fromIntegral $ B.length v) <> byteString v) fold ((\v -> int32BE (fromIntegral $ B.length v) <> byteString v)
<$> values) <> <$> values) <>
-- the specified format code is applied to all result columns (if any) -- `1` means that the specified format code is applied to all
-- result columns (if any)
int16BE 1 <> int16BE 1 <>
encodeFormat resultFormat encodeFormat resultFormat
encodeClientMessage (CloseStatement stmtName) encodeClientMessage (CloseStatement (StatementName stmtName))
= prependHeader 'C' $ char8 'S' <> pgString stmtName = prependHeader 'C' $ char8 'S' <> pgString stmtName
encodeClientMessage (ClosePortal portalName) encodeClientMessage (ClosePortal (PortalName portalName))
= prependHeader 'C' $ char8 'P' <> pgString portalName = prependHeader 'C' $ char8 'P' <> pgString portalName
encodeClientMessage (DescribeStatement stmtName) encodeClientMessage (DescribeStatement (StatementName stmtName))
= prependHeader 'D' $ char8 'S' <> pgString stmtName = prependHeader 'D' $ char8 'S' <> pgString stmtName
encodeClientMessage (DescribePortal portalName) encodeClientMessage (DescribePortal (PortalName portalName))
= prependHeader 'D' $ char8 'P' <> pgString portalName = prependHeader 'D' $ char8 'P' <> pgString portalName
encodeClientMessage (Execute portalName) encodeClientMessage (Execute (PortalName portalName))
= prependHeader 'E' $ = prependHeader 'E' $
pgString portalName <> pgString portalName <>
--Maximum number of rows to return, if portal contains a query that --Maximum number of rows to return, if portal contains a query that
@ -58,15 +61,15 @@ encodeClientMessage (Execute portalName)
int32BE 0 int32BE 0
encodeClientMessage Flush encodeClientMessage Flush
= prependHeader 'H' mempty = prependHeader 'H' mempty
encodeClientMessage (Parse stmtName stmt oids) encodeClientMessage (Parse (StatementName stmtName) (StatementSQL stmt) oids)
= prependHeader 'P' $ = prependHeader 'P' $
pgString stmtName <> pgString stmtName <>
pgString stmt <> pgString stmt <>
int16BE (fromIntegral $ V.length oids) <> int16BE (fromIntegral $ V.length oids) <>
fold (int32BE <$> oids) fold (int32BE . unOid <$> oids)
encodeClientMessage (PasswordMessage passText) encodeClientMessage (PasswordMessage (PasswordText passText))
= prependHeader 'p' $ pgString passText = prependHeader 'p' $ pgString passText
encodeClientMessage (Query stmt) encodeClientMessage (SimpleQuery (StatementSQL stmt))
= prependHeader 'Q' $ pgString stmt = prependHeader 'Q' $ pgString stmt
encodeClientMessage Sync encodeClientMessage Sync
= prependHeader 'S' mempty = prependHeader 'S' mempty

View File

@ -22,7 +22,7 @@ storageStatement :: StatementStorage -> StatementSQL -> IO StatementName
storageStatement (StatementStorage table counter) stmt = do storageStatement (StatementStorage table counter) stmt = do
n <- readIORef counter n <- readIORef counter
writeIORef counter $ n + 1 writeIORef counter $ n + 1
let name = pack $ show n let name = StatementName . pack $ show n
H.insert table name stmt H.insert table stmt name
pure name pure name

View File

@ -1,14 +1,15 @@
module Database.PostgreSQL.Protocol.Types where module Database.PostgreSQL.Protocol.Types where
import Data.Word (Word32, Word8) import Data.Word (Word32, Word8)
import Data.Int (Int32) import Data.Int (Int32, Int16)
import Data.Hashable (Hashable)
import qualified Data.ByteString as B import qualified Data.ByteString as B
import qualified Data.Vector as V import qualified Data.Vector as V
-- Common -- Common
newtype Oid = Oid Int32 deriving (Show) newtype Oid = Oid { unOid :: Int32 } deriving (Show)
newtype StatementName = StatementName B.ByteString deriving (Show) newtype StatementName = StatementName B.ByteString deriving (Show)
newtype StatementSQL = StatementSQL B.ByteString deriving (Show) newtype StatementSQL = StatementSQL B.ByteString deriving (Show, Eq, Hashable)
newtype PortalName = PortalName B.ByteString deriving (Show) newtype PortalName = PortalName B.ByteString deriving (Show)
newtype ChannelName = ChannelName B.ByteString deriving (Show) newtype ChannelName = ChannelName B.ByteString deriving (Show)
@ -21,7 +22,7 @@ newtype MD5Salt = MD5Salt Word32 deriving (Show)
newtype ServerProccessId = ServerProcessId Int32 deriving (Show) newtype ServerProccessId = ServerProcessId Int32 deriving (Show)
newtype ServerSecretKey = ServerSecrecKey Int32 deriving (Show) newtype ServerSecretKey = ServerSecrecKey Int32 deriving (Show)
newtype RowsCount = RowsCount Word newtype RowsCount = RowsCount Word deriving (Show)
-- | Information about completed command. -- | Information about completed command.
data CommandResult data CommandResult
@ -34,6 +35,7 @@ data CommandResult
| MoveCompleted RowsCount | MoveCompleted RowsCount
| FetchCompleted RowsCount | FetchCompleted RowsCount
| CopyCompleted RowsCount | CopyCompleted RowsCount
deriving (Show)
-- | Parameters of the current connection. -- | Parameters of the current connection.
-- We store only the parameters that cannot change after startup. -- We store only the parameters that cannot change after startup.
@ -105,7 +107,7 @@ data ServerMessage
= BackendKeyData ServerProccessId ServerSecretKey = BackendKeyData ServerProccessId ServerSecretKey
| BindComplete | BindComplete
| CloseComplete | CloseComplete
| CommandComplete CommandTag | CommandComplete CommandResult
| DataRow (V.Vector B.ByteString) -- the values of a result | DataRow (V.Vector B.ByteString) -- the values of a result
| EmptyQueryResponse | EmptyQueryResponse
| ErrorResponse ErrorDesc | ErrorResponse ErrorDesc
@ -175,7 +177,7 @@ data ErrorDesc = ErrorDesc
, errorDataType :: Maybe B.ByteString , errorDataType :: Maybe B.ByteString
, errorConstraint :: Maybe B.ByteString , errorConstraint :: Maybe B.ByteString
, errorSourceFilename :: Maybe B.ByteString , errorSourceFilename :: Maybe B.ByteString
, errorSourceLine :: Maybe B.Int , errorSourceLine :: Maybe Int
, errorRoutine :: Maybe B.ByteString , errorRoutine :: Maybe B.ByteString
} deriving (Show) } deriving (Show)
@ -195,7 +197,7 @@ data NoticeDesc = NoticeDesc
, noticeDataType :: Maybe B.ByteString , noticeDataType :: Maybe B.ByteString
, noticeConstraint :: Maybe B.ByteString , noticeConstraint :: Maybe B.ByteString
, noticeSourceFilename :: Maybe B.ByteString , noticeSourceFilename :: Maybe B.ByteString
, noticeSourceLine :: Maybe B.Int , noticeSourceLine :: Maybe Int
, noticeRoutine :: Maybe B.ByteString , noticeRoutine :: Maybe B.ByteString
} deriving (Show) } deriving (Show)