From da8bba2c74220c74721e74877dfa55ef14e9e1ce Mon Sep 17 00:00:00 2001 From: VyacheslavHashov Date: Wed, 18 Jan 2017 19:57:05 +0300 Subject: [PATCH] Fixed encoders with newtypes --- src/Database/PostgreSQL/Protocol/Encoders.hs | 29 ++++++++++--------- .../PostgreSQL/Protocol/StatementStorage.hs | 4 +-- src/Database/PostgreSQL/Protocol/Types.hs | 16 +++++----- 3 files changed, 27 insertions(+), 22 deletions(-) diff --git a/src/Database/PostgreSQL/Protocol/Encoders.hs b/src/Database/PostgreSQL/Protocol/Encoders.hs index c711311..05240b7 100644 --- a/src/Database/PostgreSQL/Protocol/Encoders.hs +++ b/src/Database/PostgreSQL/Protocol/Encoders.hs @@ -17,21 +17,23 @@ currentVersion = 3 * 256 * 256 encodeStartMessage :: StartMessage -> Builder -- Options except user and database are not supported -encodeStartMessage (StartupMessage uname dbname) = +encodeStartMessage (StartupMessage (Username uname) (DatabaseName dbname)) = int32BE (len + 4) <> payload where len = fromIntegral $ BL.length $ toLazyByteString payload payload = int32BE currentVersion <> pgString "user" <> pgString uname <> pgString "database" <> pgString dbname <> word8 0 + -- TODO encodeStartMessage SSLRequest = undefined encodeClientMessage :: ClientMessage -> Builder -encodeClientMessage (Bind portalName stmtName paramFormat values resultFormat) +encodeClientMessage (Bind (PortalName portalName) (StatementName stmtName) + paramFormat values resultFormat) = prependHeader 'B' $ pgString portalName <> pgString stmtName <> - -- the specified format code is applied to all parameters + -- `1` means that the specified format code is applied to all parameters int16BE 1 <> encodeFormat paramFormat <> int16BE (fromIntegral $ V.length values) <> @@ -39,18 +41,19 @@ encodeClientMessage (Bind portalName stmtName paramFormat values resultFormat) -- follow in the NULL case. fold ((\v -> int32BE (fromIntegral $ B.length v) <> byteString v) <$> values) <> - -- the specified format code is applied to all result columns (if any) + -- `1` means that the specified format code is applied to all + -- result columns (if any) int16BE 1 <> encodeFormat resultFormat -encodeClientMessage (CloseStatement stmtName) +encodeClientMessage (CloseStatement (StatementName stmtName)) = prependHeader 'C' $ char8 'S' <> pgString stmtName -encodeClientMessage (ClosePortal portalName) +encodeClientMessage (ClosePortal (PortalName portalName)) = prependHeader 'C' $ char8 'P' <> pgString portalName -encodeClientMessage (DescribeStatement stmtName) +encodeClientMessage (DescribeStatement (StatementName stmtName)) = prependHeader 'D' $ char8 'S' <> pgString stmtName -encodeClientMessage (DescribePortal portalName) +encodeClientMessage (DescribePortal (PortalName portalName)) = prependHeader 'D' $ char8 'P' <> pgString portalName -encodeClientMessage (Execute portalName) +encodeClientMessage (Execute (PortalName portalName)) = prependHeader 'E' $ pgString portalName <> --Maximum number of rows to return, if portal contains a query that @@ -58,15 +61,15 @@ encodeClientMessage (Execute portalName) int32BE 0 encodeClientMessage Flush = prependHeader 'H' mempty -encodeClientMessage (Parse stmtName stmt oids) +encodeClientMessage (Parse (StatementName stmtName) (StatementSQL stmt) oids) = prependHeader 'P' $ pgString stmtName <> pgString stmt <> int16BE (fromIntegral $ V.length oids) <> - fold (int32BE <$> oids) -encodeClientMessage (PasswordMessage passText) + fold (int32BE . unOid <$> oids) +encodeClientMessage (PasswordMessage (PasswordText passText)) = prependHeader 'p' $ pgString passText -encodeClientMessage (Query stmt) +encodeClientMessage (SimpleQuery (StatementSQL stmt)) = prependHeader 'Q' $ pgString stmt encodeClientMessage Sync = prependHeader 'S' mempty diff --git a/src/Database/PostgreSQL/Protocol/StatementStorage.hs b/src/Database/PostgreSQL/Protocol/StatementStorage.hs index 9386b49..6f83229 100644 --- a/src/Database/PostgreSQL/Protocol/StatementStorage.hs +++ b/src/Database/PostgreSQL/Protocol/StatementStorage.hs @@ -22,7 +22,7 @@ storageStatement :: StatementStorage -> StatementSQL -> IO StatementName storageStatement (StatementStorage table counter) stmt = do n <- readIORef counter writeIORef counter $ n + 1 - let name = pack $ show n - H.insert table name stmt + let name = StatementName . pack $ show n + H.insert table stmt name pure name diff --git a/src/Database/PostgreSQL/Protocol/Types.hs b/src/Database/PostgreSQL/Protocol/Types.hs index a54fe47..8fd068f 100644 --- a/src/Database/PostgreSQL/Protocol/Types.hs +++ b/src/Database/PostgreSQL/Protocol/Types.hs @@ -1,14 +1,15 @@ module Database.PostgreSQL.Protocol.Types where import Data.Word (Word32, Word8) -import Data.Int (Int32) +import Data.Int (Int32, Int16) +import Data.Hashable (Hashable) import qualified Data.ByteString as B import qualified Data.Vector as V -- Common -newtype Oid = Oid Int32 deriving (Show) +newtype Oid = Oid { unOid :: Int32 } deriving (Show) newtype StatementName = StatementName B.ByteString deriving (Show) -newtype StatementSQL = StatementSQL B.ByteString deriving (Show) +newtype StatementSQL = StatementSQL B.ByteString deriving (Show, Eq, Hashable) newtype PortalName = PortalName B.ByteString deriving (Show) newtype ChannelName = ChannelName B.ByteString deriving (Show) @@ -21,7 +22,7 @@ newtype MD5Salt = MD5Salt Word32 deriving (Show) newtype ServerProccessId = ServerProcessId Int32 deriving (Show) newtype ServerSecretKey = ServerSecrecKey Int32 deriving (Show) -newtype RowsCount = RowsCount Word +newtype RowsCount = RowsCount Word deriving (Show) -- | Information about completed command. data CommandResult @@ -34,6 +35,7 @@ data CommandResult | MoveCompleted RowsCount | FetchCompleted RowsCount | CopyCompleted RowsCount + deriving (Show) -- | Parameters of the current connection. -- We store only the parameters that cannot change after startup. @@ -105,7 +107,7 @@ data ServerMessage = BackendKeyData ServerProccessId ServerSecretKey | BindComplete | CloseComplete - | CommandComplete CommandTag + | CommandComplete CommandResult | DataRow (V.Vector B.ByteString) -- the values of a result | EmptyQueryResponse | ErrorResponse ErrorDesc @@ -175,7 +177,7 @@ data ErrorDesc = ErrorDesc , errorDataType :: Maybe B.ByteString , errorConstraint :: Maybe B.ByteString , errorSourceFilename :: Maybe B.ByteString - , errorSourceLine :: Maybe B.Int + , errorSourceLine :: Maybe Int , errorRoutine :: Maybe B.ByteString } deriving (Show) @@ -195,7 +197,7 @@ data NoticeDesc = NoticeDesc , noticeDataType :: Maybe B.ByteString , noticeConstraint :: Maybe B.ByteString , noticeSourceFilename :: Maybe B.ByteString - , noticeSourceLine :: Maybe B.Int + , noticeSourceLine :: Maybe Int , noticeRoutine :: Maybe B.ByteString } deriving (Show)