Refactor Protoc directory

This commit is contained in:
VyacheslavHashov 2017-07-13 16:20:23 +03:00
parent 1ba93bd6c4
commit 1b89885392
18 changed files with 236 additions and 139 deletions

View File

@ -37,7 +37,6 @@ library
, Database.PostgreSQL.Protocol.Codecs.PgTypes
, Database.PostgreSQL.Protocol.Codecs.Time
, Database.PostgreSQL.Protocol.Codecs.Numeric
other-modules: Database.PostgreSQL.Protocol.Utils
build-depends: base >= 4.7 && < 5
, bytestring
, socket

View File

@ -23,33 +23,34 @@ module Database.PostgreSQL.Driver.Connection
, defaultFilter
) where
import Data.Monoid ((<>))
import Control.Monad (void, when)
import Control.Concurrent (forkIOWithUnmask, killThread, ThreadId, threadDelay
, mkWeakThreadId)
import Data.Monoid ((<>))
import Control.Concurrent (forkIOWithUnmask, killThread, ThreadId,
threadDelay , mkWeakThreadId)
import Control.Concurrent.STM (atomically)
import Control.Concurrent.STM.TQueue (TQueue, writeTQueue, newTQueueIO)
import Control.Concurrent.STM (atomically)
import Control.Exception (SomeException, bracketOnError, catch, mask_,
catch, throwIO)
import GHC.Conc (labelThread)
import Crypto.Hash (hash, Digest, MD5)
import System.Mem.Weak (Weak, deRefWeak)
import System.Socket (eBadFileDescriptor)
import Control.Exception (SomeException, bracketOnError, catch,
mask_, catch, throwIO)
import Control.Monad (void, when)
import GHC.Conc (labelThread)
import System.Mem.Weak (Weak, deRefWeak)
import Crypto.Hash (hash, Digest, MD5)
import System.Socket (eBadFileDescriptor)
import qualified Data.HashMap.Strict as HM
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as BS(pack, unpack)
import Database.PostgreSQL.Protocol.DataRows
import Database.PostgreSQL.Protocol.Encoders
import Database.PostgreSQL.Protocol.Decoders
import Database.PostgreSQL.Protocol.Parsers
import Database.PostgreSQL.Protocol.DataRows
import Database.PostgreSQL.Protocol.Types
import Database.PostgreSQL.Protocol.Store.Encode (runEncode, Encode)
import Database.PostgreSQL.Protocol.Store.Decode (runDecode)
import Database.PostgreSQL.Driver.Error
import Database.PostgreSQL.Driver.Settings
import Database.PostgreSQL.Driver.StatementStorage
import Database.PostgreSQL.Driver.Error
import Database.PostgreSQL.Driver.RawConnection
-- | Public

View File

@ -17,9 +17,10 @@ module Database.PostgreSQL.Driver.Error
, throwAuthErrorInIO
) where
import Control.Exception (throwIO, Exception(..), SomeException)
import Data.ByteString (ByteString)
import System.Socket (AddressInfoException)
import Control.Exception (throwIO, Exception(..), SomeException)
import Data.ByteString (ByteString)
import System.Socket (AddressInfoException)
import qualified Data.ByteString.Char8 as BS
import Database.PostgreSQL.Protocol.Types (ErrorDesc)

View File

@ -4,10 +4,11 @@ module Database.PostgreSQL.Driver.RawConnection
, createRawConnection
) where
import Control.Monad (void, when)
import Control.Exception (bracketOnError, try)
import Data.Monoid ((<>))
import Foreign (castPtr, plusPtr)
import Control.Monad (void, when)
import Control.Exception (bracketOnError, try)
import Data.Monoid ((<>))
import Foreign (castPtr, plusPtr)
import System.Socket (socket, AddressInfo(..), getAddressInfo, socketAddress,
aiV4Mapped, AddressInfoException, Socket, connect,
close, receive, send)

View File

@ -4,8 +4,8 @@ module Database.PostgreSQL.Driver.Settings
, defaultConnectionSettings
) where
import Data.Word (Word16)
import Data.ByteString (ByteString)
import Data.Word (Word16)
import Data.ByteString (ByteString)
data TlsMode = RequiredTls | NoTls
deriving (Show, Eq)
@ -26,6 +26,7 @@ data ConnectionSettings = ConnectionSettings
, settingsTls :: TlsMode
} deriving (Show)
-- TODO change defaults
defaultConnectionSettings :: ConnectionSettings
defaultConnectionSettings = ConnectionSettings
{ settingsHost = "localhost"

View File

@ -1,18 +1,37 @@
module Database.PostgreSQL.Protocol.Codecs.Decoders where
module Database.PostgreSQL.Protocol.Codecs.Decoders
( dataRowHeader
, getNonNullable
, getNullable
, FieldDecoder
, bool
, bytea
, char
, date
, float4
, float8
, int2
, int4
, int8
, interval
, bsJsonText
, bsJsonBytes
, numeric
, bsText
, timestamp
, timestamptz
, uuid
) where
import Data.Word
import Data.Int
import Data.Maybe
import Data.Char
import Data.Scientific
import Data.UUID (UUID, fromWords)
import Data.Time (Day, UTCTime, LocalTime, DiffTime)
import qualified Data.ByteString as B
import Prelude hiding (bool)
import Control.Monad (replicateM, (<$!>))
import Data.ByteString (ByteString)
import Data.Char (chr)
import Data.Int (Int16, Int32, Int64)
import Data.Scientific (Scientific)
import Data.Time (Day, UTCTime, LocalTime, DiffTime)
import Data.UUID (UUID, fromWords)
import qualified Data.Vector as V
import Control.Monad
import Prelude hiding (bool)
import Database.PostgreSQL.Protocol.Store.Decode
import Database.PostgreSQL.Protocol.Types
import Database.PostgreSQL.Protocol.Codecs.Time
@ -85,7 +104,7 @@ bool :: FieldDecoder Bool
bool _ = (== 1) <$> getWord8
{-# INLINE bytea #-}
bytea :: FieldDecoder B.ByteString
bytea :: FieldDecoder ByteString
bytea = getByteString
{-# INLINE char #-}
@ -122,12 +141,12 @@ interval _ = intervalToDiffTime <$> getInt64BE <*> getInt32BE <*> getInt32BE
-- | Decodes representation of JSON as @ByteString@.
{-# INLINE bsJsonText #-}
bsJsonText :: FieldDecoder B.ByteString
bsJsonText :: FieldDecoder ByteString
bsJsonText = getByteString
-- | Decodes representation of JSONB as @ByteString@.
{-# INLINE bsJsonBytes #-}
bsJsonBytes :: FieldDecoder B.ByteString
bsJsonBytes :: FieldDecoder ByteString
bsJsonBytes len = getWord8 *> getByteString (len - 1)
{-# INLINE numeric #-}
@ -142,7 +161,7 @@ numeric _ = do
-- | Decodes text without applying encoding.
{-# INLINE bsText #-}
bsText :: FieldDecoder B.ByteString
bsText :: FieldDecoder ByteString
bsText = getByteString
{-# INLINE timestamp #-}

View File

@ -1,16 +1,30 @@
module Database.PostgreSQL.Protocol.Codecs.Encoders where
module Database.PostgreSQL.Protocol.Codecs.Encoders
( bool
, bytea
, char
, date
, float4
, float8
, int2
, int4
, int8
, interval
, bsJsonText
, bsJsonBytes
, numeric
, bsText
, timestamp
, timestamptz
, uuid
) where
import Data.Word
import Data.Monoid ((<>))
import Data.Int
import Data.Char
import Data.Scientific
import Data.UUID (UUID, toWords)
import Data.Time (Day, UTCTime, LocalTime, DiffTime)
import qualified Data.ByteString as B
import qualified Data.Vector as V
import Control.Monad
import Data.ByteString (ByteString)
import Data.Char (ord)
import Data.Int (Int16, Int32, Int64)
import Data.Monoid ((<>))
import Data.Scientific (Scientific)
import Data.Time (Day, UTCTime, LocalTime, DiffTime)
import Data.UUID (UUID, toWords)
import Database.PostgreSQL.Protocol.Store.Encode
import Database.PostgreSQL.Protocol.Types
@ -27,7 +41,7 @@ bool False = putWord8 0
bool True = putWord8 1
{-# INLINE bytea #-}
bytea :: B.ByteString -> Encode
bytea :: ByteString -> Encode
bytea = putByteString
{-# INLINE char #-}
@ -65,12 +79,12 @@ interval v = let (mcs, days, months) = diffTimeToInterval v
-- | Encodes representation of JSON as @ByteString@.
{-# INLINE bsJsonText #-}
bsJsonText :: B.ByteString -> Encode
bsJsonText :: ByteString -> Encode
bsJsonText = putByteString
-- | Encodes representation of JSONB as @ByteString@.
{-# INLINE bsJsonBytes #-}
bsJsonBytes :: B.ByteString -> Encode
bsJsonBytes :: ByteString -> Encode
bsJsonBytes bs = putWord8 1 <> putByteString bs
{-# INLINE numeric #-}
@ -85,7 +99,7 @@ numeric n =
-- | Encodes text.
{-# INLINE bsText #-}
bsText :: B.ByteString -> Encode
bsText :: ByteString -> Encode
bsText = putByteString
{-# INLINE timestamp #-}

View File

@ -1,12 +1,17 @@
{-# language LambdaCase #-}
module Database.PostgreSQL.Protocol.Codecs.Numeric where
module Database.PostgreSQL.Protocol.Codecs.Numeric
( scientificToNumeric
, numericToScientific
, toNumericSign
, fromNumericSign
) where
import Data.Word (Word16)
import Data.Int (Int16)
import Data.Foldable (foldl')
import Data.Scientific (Scientific, scientific, base10Exponent, coefficient)
import Data.Int (Int16)
import Data.List (unfoldr)
import Data.Scientific (Scientific, scientific, base10Exponent, coefficient)
import Data.Word (Word16)
{-# INLINE scientificToNumeric #-}
scientificToNumeric :: Scientific -> (Word16, Int16, Word16, [Word16])

View File

@ -1,7 +1,34 @@
{-
Oids for built-in types.
-}
module Database.PostgreSQL.Protocol.Codecs.PgTypes where
module Database.PostgreSQL.Protocol.Codecs.PgTypes
( Oids(..)
-- * Primitives
, bool
, bytea
, char
, date
, float4
, float8
, int2
, int4
, int8
, interval
, json
, jsonb
, numeric
, text
, timestamp
, timestamptz
, uuid
-- * Ranges
, daterange
, int4range
, int8range
, numrange
, tsrange
, tstzrange
) where
import Data.Word (Word32)

View File

@ -50,7 +50,6 @@ intervalToDiffTime mcs days months = picosecondsToDiffTime . mcsToPcs $
microsInDay * (fromIntegral months * daysInMonth + fromIntegral days)
+ fromIntegral mcs
-- TODO consider adjusted encoding
{-# INLINE diffTimeToInterval #-}
diffTimeToInterval :: DiffTime -> (Int64, Int32, Int32)
diffTimeToInterval dt = (fromIntegral $ diffTimeToMcs dt, 0, 0)

View File

@ -1,3 +1,4 @@
{-# language ForeignFunctionInterface #-}
module Database.PostgreSQL.Protocol.DataRows
( loopExtractDataRows
, countDataRows
@ -6,22 +7,24 @@ module Database.PostgreSQL.Protocol.DataRows
, decodeOneRow
) where
import Data.Monoid ((<>))
import Data.Word (Word8, byteSwap32)
import Foreign (peek, peekByteOff, castPtr)
import Data.Foldable (traverse_)
import Data.Monoid ((<>))
import Data.Word (Word8, byteSwap32)
import Foreign (Ptr, alloca, peek, peekByteOff, castPtr)
import Foreign.C.Types (CInt, CSize(..), CChar, CULong)
import Foreign (Ptr, peek, alloca)
import System.IO.Unsafe (unsafePerformIO)
import qualified Data.ByteString as B
import qualified Data.ByteString.Unsafe as B
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as MV
import qualified Data.List as L
import Data.Foldable
import System.IO.Unsafe
import Database.PostgreSQL.Driver.Error
import Database.PostgreSQL.Protocol.Types
import Database.PostgreSQL.Protocol.Parsers
import Database.PostgreSQL.Protocol.Store.Decode
import Database.PostgreSQL.Protocol.Utils
-- Optimized loop for extracting chunks of DataRows.
-- Ignores all messages from database that do not relate to data.
@ -188,3 +191,29 @@ countDataRows = foldlDataRows (\acc (DataChunk c _) -> acc + c) 0
{-# INLINE flattenDataRows #-}
flattenDataRows :: DataRows -> B.ByteString
flattenDataRows = foldlDataRows (\acc (DataChunk _ bs) -> acc <> bs) ""
--
-- C utils
--
data ScanRowResult = ScanRowResult
{-# UNPACK #-} !DataChunk -- chunk of datarows, may be empty
{-# UNPACK #-} !B.ByteString -- the rest of string
{-# UNPACK #-} !Int -- reason code
-- | Scans `ByteString` for a chunk of `DataRow`s.
{-# INLINE scanDataRows #-}
scanDataRows :: B.ByteString -> IO ScanRowResult
scanDataRows bs =
alloca $ \countPtr ->
alloca $ \reasonPtr ->
B.unsafeUseAsCStringLen bs $ \(ptr, len) -> do
offset <- fromIntegral <$>
c_scan_datarows ptr (fromIntegral len) countPtr reasonPtr
reason <- fromIntegral <$> peek reasonPtr
count <- fromIntegral <$> peek countPtr
let (ch, rest) = B.splitAt offset bs
pure $ ScanRowResult (DataChunk count ch) rest reason
foreign import ccall unsafe "static pw_utils.h scan_datarows" c_scan_datarows
:: Ptr CChar -> CSize -> Ptr CULong -> Ptr CInt -> IO CSize

View File

@ -8,10 +8,11 @@ module Database.PostgreSQL.Protocol.Decoders
, decodeServerMessage
) where
import Data.Char (chr)
import qualified Data.Vector as V
import Data.Char (chr)
import Data.ByteString.Char8 as BS(unpack)
import qualified Data.ByteString as B
import Data.ByteString.Char8 as BS(unpack)
import qualified Data.Vector as V
import Database.PostgreSQL.Protocol.Types
import Database.PostgreSQL.Protocol.Store.Decode

View File

@ -3,9 +3,10 @@ module Database.PostgreSQL.Protocol.Encoders
, encodeClientMessage
) where
import Data.Word (Word32)
import Data.Monoid ((<>))
import Data.Char (ord)
import Data.Char (ord)
import Data.Monoid ((<>))
import Data.Word (Word32)
import qualified Data.ByteString as B
import Database.PostgreSQL.Protocol.Types

View File

@ -1,5 +1,6 @@
{-# language RecordWildCards #-}
-- TODO doc
-- Helper parser that works with ByteString,
-- not Decode
module Database.PostgreSQL.Protocol.Parsers
@ -10,12 +11,13 @@ module Database.PostgreSQL.Protocol.Parsers
, parseCommandResult
) where
import Data.Monoid ((<>))
import Data.Char (chr)
import Data.Maybe (fromMaybe)
import Data.Char (chr)
import Data.Maybe (fromMaybe)
import Data.Monoid ((<>))
import Text.Read (readMaybe)
import Data.ByteString.Char8 as BS (readInteger, readInt, unpack, pack)
import qualified Data.ByteString as B
import Data.ByteString.Char8 as BS(readInteger, readInt, unpack, pack)
import Text.Read (readMaybe)
import qualified Data.HashMap.Strict as HM
import Database.PostgreSQL.Protocol.Types

View File

@ -1,18 +1,32 @@
module Database.PostgreSQL.Protocol.Store.Decode where
module Database.PostgreSQL.Protocol.Store.Decode
( Decode
, runDecode
, runDecodeIO
, embedIO
, skipBytes
, getByteString
, getByteStringNull
, getWord8
, getWord16BE
, getWord32BE
, getWord64BE
, getInt16BE
, getInt32BE
, getInt64BE
, getFloat32BE
, getFloat64BE
) where
import Prelude hiding (takeWhile)
import Prelude hiding (takeWhile)
import Data.Int (Int16, Int32, Int64)
import Data.Word (Word8, Word16, Word32, Word64,
byteSwap16, byteSwap32, byteSwap64)
import Foreign (Ptr, Storable, alloca, peek, poke, castPtr, plusPtr)
import Data.Store.Core (Peek(..), PeekResult(..), decodeExPortionWith,
decodeIOPortionWith)
import qualified Data.ByteString as B
import Data.Word
import Data.Int
import Data.Tuple
import Data.Store.Core
import Foreign
import Control.Monad
import Control.Applicative
import Data.ByteString (ByteString)
import qualified Data.ByteString.Internal as B
newtype Decode a = Decode (Peek a)
deriving (Functor, Applicative, Monad)
@ -42,8 +56,6 @@ prim len f = Decode $ Peek $ \ps ptr -> do
let !newPtr = ptr `plusPtr` len
pure (PeekResult newPtr v)
-- Public
{-# INLINE skipBytes #-}
skipBytes :: Int -> Decode ()
skipBytes n = prim n $ const $ pure ()

View File

@ -1,15 +1,30 @@
module Database.PostgreSQL.Protocol.Store.Encode where
module Database.PostgreSQL.Protocol.Store.Encode
( Encode
, getEncodeLen
, runEncode
, putByteString
, putByteStringNull
, putWord8
, putWord16BE
, putWord32BE
, putWord64BE
, putInt16BE
, putInt32BE
, putInt64BE
, putFloat32BE
, putFloat64BE
) where
import Data.Monoid (Monoid(..), (<>))
import Foreign (poke, plusPtr, Ptr)
import Data.Int (Int16, Int32)
import Data.Word
import Data.Monoid (Monoid(..), (<>))
import Foreign (Storable, alloca, peek, poke, castPtr, plusPtr, Ptr)
import Data.Int (Int16, Int32, Int64)
import Data.Word (Word8, Word16, Word32, Word64,
byteSwap16, byteSwap32, byteSwap64)
import Foreign
import Data.ByteString (ByteString)
import Data.ByteString.Internal as B(toForeignPtr)
import Data.Store.Core (Poke(..), unsafeEncodeWith, pokeStatePtr,
pokeFromForeignPtr)
import Data.ByteString (ByteString)
import Data.ByteString.Internal (toForeignPtr)
import Data.Store.Core (Poke(..), unsafeEncodeWith, pokeStatePtr,
pokeFromForeignPtr)
data Encode = Encode {-# UNPACK #-} !Int !(Poke ())

View File

@ -10,12 +10,14 @@ module Database.PostgreSQL.Protocol.Types where
-- * bind command can have different formats for parameters and results
-- but we assume that there will be one format for all.
import Data.Word (Word32, Word8, Word16)
import Data.Int (Int32, Int16)
import Data.Hashable (Hashable)
import Data.ByteString as B(ByteString)
import Data.Int (Int32, Int16)
import Data.Word (Word32, Word8, Word16)
import Data.ByteString as B (ByteString)
import Data.Hashable (Hashable)
import Data.Vector (Vector)
import qualified Data.ByteString.Lazy as BL(ByteString)
import Data.Vector (Vector)
import Database.PostgreSQL.Protocol.Store.Encode (Encode)
-- Common

View File

@ -1,32 +0,0 @@
{-# language ForeignFunctionInterface #-}
module Database.PostgreSQL.Protocol.Utils where
import Foreign.C.Types (CInt, CSize(..), CChar, CULong)
import Foreign (Ptr, peek, alloca)
import qualified Data.ByteString as B
import qualified Data.ByteString.Unsafe as B
import Database.PostgreSQL.Protocol.Types (DataChunk(..))
data ScanRowResult = ScanRowResult
{-# UNPACK #-} !DataChunk -- chunk of datarows, may be empty
{-# UNPACK #-} !B.ByteString -- the rest of string
{-# UNPACK #-} !Int -- reason code
-- | Scans `ByteString` for a chunk of `DataRow`s.
{-# INLINE scanDataRows #-}
scanDataRows :: B.ByteString -> IO ScanRowResult
scanDataRows bs =
alloca $ \countPtr ->
alloca $ \reasonPtr ->
B.unsafeUseAsCStringLen bs $ \(ptr, len) -> do
offset <- fromIntegral <$>
c_scan_datarows ptr (fromIntegral len) countPtr reasonPtr
reason <- fromIntegral <$> peek reasonPtr
count <- fromIntegral <$> peek countPtr
let (ch, rest) = B.splitAt offset bs
pure $ ScanRowResult (DataChunk count ch) rest reason
foreign import ccall unsafe "static pw_utils.h scan_datarows" c_scan_datarows
:: Ptr CChar -> CSize -> Ptr CULong -> Ptr CInt -> IO CSize