Refactor Driver directory

This commit is contained in:
VyacheslavHashov 2017-07-13 17:04:30 +03:00
parent 1b89885392
commit 856616005e
5 changed files with 70 additions and 48 deletions

View File

@ -58,6 +58,7 @@ library
BangPatterns
OverloadedStrings
GeneralizedNewtypeDeriving
LambdaCase
cc-options: -O2 -Wall
test-suite postgres-wire-test-connection

View File

@ -56,11 +56,11 @@ import Database.PostgreSQL.Driver.RawConnection
-- | Public
-- Connection parametrized by message type in chan.
data AbsConnection mt = AbsConnection
{ connRawConnection :: RawConnection
, connReceiverThread :: Weak ThreadId
, connStatementStorage :: StatementStorage
, connParameters :: ConnectionParameters
, connOutChan :: TQueue (Either ReceiverException mt)
{ connRawConnection :: !RawConnection
, connReceiverThread :: !(Weak ThreadId)
, connStatementStorage :: !StatementStorage
, connParameters :: !ConnectionParameters
, connOutChan :: !(TQueue (Either ReceiverException mt))
}
type Connection = AbsConnection DataMessage
@ -122,15 +122,18 @@ connectCommon' settings msgFilter = connectWith settings $ \rawConn params ->
-- Low-level sending functions
{-# INLINE sendStartMessage #-}
sendStartMessage :: RawConnection -> StartMessage -> IO ()
sendStartMessage rawConn msg = void $
rSend rawConn . runEncode $ encodeStartMessage msg
-- Only for testings and simple queries
{-# INLINE sendMessage #-}
sendMessage :: RawConnection -> ClientMessage -> IO ()
sendMessage rawConn msg = void $
rSend rawConn . runEncode $ encodeClientMessage msg
{-# INLINE sendEncode #-}
sendEncode :: AbsConnection c -> Encode -> IO ()
sendEncode conn = void . rSend (connRawConnection conn) . runEncode
@ -290,6 +293,11 @@ receiverThreadCommon rawConn chan msgFilter ntfHandler = go ""
dispatchIfNotification (NotificationResponse ntf) handler = handler ntf
dispatchIfNotification _ _ = pure ()
-- | Helper to read from queue.
{-# INLINE writeChan #-}
writeChan :: TQueue a -> a -> IO ()
writeChan q = atomically . writeTQueue q
defaultNotificationHandler :: NotificationHandler
defaultNotificationHandler = const $ pure ()
@ -332,7 +340,3 @@ defaultFilter msg = case msg of
-- as result for `describe` message
RowDescription{} -> True
-- | Helper to read from queue.
writeChan :: TQueue a -> a -> IO ()
writeChan q = atomically . writeTQueue q

View File

@ -12,13 +12,12 @@ module Database.PostgreSQL.Driver.Query
, collectUntilReadyForQuery
) where
import Data.Foldable
import Data.Monoid
import Data.Bifunctor
import qualified Data.Vector as V
import qualified Data.ByteString as B
import Control.Concurrent.STM.TQueue (TQueue, readTQueue )
import Control.Concurrent.STM (atomically)
import Control.Concurrent.STM (atomically)
import Data.Foldable (fold)
import Data.Monoid ((<>))
import Data.ByteString (ByteString)
import Data.Vector (Vector)
import Database.PostgreSQL.Protocol.Encoders
import Database.PostgreSQL.Protocol.Store.Encode
@ -31,26 +30,30 @@ import Database.PostgreSQL.Driver.StatementStorage
-- Public
data Query = Query
{ qStatement :: B.ByteString
, qValues :: [(Oid, Maybe Encode)]
, qParamsFormat :: Format
, qResultFormat :: Format
, qCachePolicy :: CachePolicy
{ qStatement :: !ByteString
, qValues :: ![(Oid, Maybe Encode)]
, qParamsFormat :: !Format
, qResultFormat :: !Format
, qCachePolicy :: !CachePolicy
} deriving (Show)
-- | Public
{- INLINE sendBatchAndFlush #-}
sendBatchAndFlush :: Connection -> [Query] -> IO ()
sendBatchAndFlush = sendBatchEndBy Flush
-- | Public
{-# INLINE sendBatchAndSync #-}
sendBatchAndSync :: Connection -> [Query] -> IO ()
sendBatchAndSync = sendBatchEndBy Sync
-- | Public
{-# INLINE sendSync #-}
sendSync :: Connection -> IO ()
sendSync conn = sendEncode conn $ encodeClientMessage Sync
-- | Public
{-# INLINABLE readNextData #-}
readNextData :: Connection -> IO (Either Error DataRows)
readNextData conn =
readChan (connOutChan conn) >>=
@ -62,6 +65,7 @@ readNextData conn =
DataReady -> throwIncorrectUsage
"Expected DataRow message, but got ReadyForQuery"
{-# INLINABLE waitReadyForQuery #-}
waitReadyForQuery :: Connection -> IO (Either Error ())
waitReadyForQuery conn =
readChan (connOutChan conn) >>=
@ -77,6 +81,7 @@ waitReadyForQuery conn =
DataReady -> pure $ Right ()
-- Helper
{-# INLINE sendBatchEndBy #-}
sendBatchEndBy :: ClientMessage -> Connection -> [Query] -> IO ()
sendBatchEndBy msg conn qs = do
batch <- constructBatch conn qs
@ -90,28 +95,27 @@ constructBatch conn = fmap fold . traverse constructSingle
pname = PortalName ""
constructSingle q = do
let stmtSQL = StatementSQL $ qStatement q
(sname, parseMessage) <- case qCachePolicy q of
AlwaysCache -> do
mName <- lookupStatement storage stmtSQL
case mName of
Nothing -> do
newName <- storeStatement storage stmtSQL
pure (newName, encodeClientMessage $
Parse newName stmtSQL (fst <$> qValues q))
Just name -> pure (name, mempty)
NeverCache -> do
let newName = defaultStatementName
pure (newName, encodeClientMessage $
Parse newName stmtSQL (fst <$> qValues q))
let bindMessage = encodeClientMessage $
Bind pname sname (qParamsFormat q) (snd <$> qValues q)
(stmtName, needParse) <- case qCachePolicy q of
AlwaysCache -> lookupStatement storage stmtSQL >>= \case
Nothing -> do
newName <- storeStatement storage stmtSQL
pure (newName, True)
Just name ->
pure (name, False)
NeverCache -> pure (defaultStatementName, True)
let parseMessage = if needParse
then encodeClientMessage $
Parse stmtName stmtSQL (fst <$> qValues q)
else mempty
bindMessage = encodeClientMessage $
Bind pname stmtName (qParamsFormat q) (snd <$> qValues q)
(qResultFormat q)
executeMessage = encodeClientMessage $
Execute pname noLimitToReceive
pure $ parseMessage <> bindMessage <> executeMessage
-- | Public
sendSimpleQuery :: ConnectionCommon -> B.ByteString -> IO (Either Error ())
sendSimpleQuery :: ConnectionCommon -> ByteString -> IO (Either Error ())
sendSimpleQuery conn q = do
sendMessage (connRawConnection conn) $ SimpleQuery (StatementSQL q)
(checkErrors =<<) <$> collectUntilReadyForQuery conn
@ -122,8 +126,8 @@ sendSimpleQuery conn q = do
-- | Public
describeStatement
:: ConnectionCommon
-> B.ByteString
-> IO (Either Error (V.Vector Oid, V.Vector FieldDescription))
-> ByteString
-> IO (Either Error (Vector Oid, Vector FieldDescription))
describeStatement conn stmt = do
sendEncode conn $
encodeClientMessage (Parse sname (StatementSQL stmt) [])
@ -135,7 +139,7 @@ describeStatement conn stmt = do
sname = StatementName ""
parseMessages msgs = case msgs of
[ParameterDescription params, NoData]
-> pure $ Right (params, V.empty)
-> pure $ Right (params, mempty)
[ParameterDescription params, RowDescription fields]
-> pure $ Right (params, fields)
xs -> maybe
@ -160,5 +164,6 @@ findFirstError [] = Nothing
findFirstError (ErrorResponse desc : _) = Just desc
findFirstError (_ : xs) = findFirstError xs
{-# INLINE readChan #-}
readChan :: TQueue a -> IO a
readChan = atomically . readTQueue

View File

@ -1,10 +1,20 @@
module Database.PostgreSQL.Driver.StatementStorage where
module Database.PostgreSQL.Driver.StatementStorage
( StatementStorage
, CachePolicy(..)
, newStatementStorage
, lookupStatement
, storeStatement
, getCacheSize
, defaultStatementName
) where
import qualified Data.HashTable.IO as H
import qualified Data.ByteString as B
import Data.Monoid ((<>))
import Data.IORef (IORef, newIORef, readIORef, writeIORef)
import Data.Word (Word)
import Data.ByteString (ByteString)
import Data.ByteString.Char8 (pack)
import Data.Word (Word)
import Data.IORef
import qualified Data.HashTable.IO as H
import Database.PostgreSQL.Protocol.Types
@ -21,16 +31,17 @@ data CachePolicy
newStatementStorage :: IO StatementStorage
newStatementStorage = StatementStorage <$> H.new <*> newIORef 0
{-# INLINE lookupStatement #-}
lookupStatement :: StatementStorage -> StatementSQL -> IO (Maybe StatementName)
lookupStatement (StatementStorage table _) = H.lookup table
-- TODO place right name
-- TODO info about exceptions and mask
{-# INLINE storeStatement #-}
storeStatement :: StatementStorage -> StatementSQL -> IO StatementName
storeStatement (StatementStorage table counter) stmt = do
n <- readIORef counter
writeIORef counter $ n + 1
let name = StatementName . pack $ show n
let name = StatementName . (statementPrefix <>) . pack $ show n
H.insert table stmt name
pure name
@ -40,3 +51,6 @@ getCacheSize (StatementStorage _ counter) = readIORef counter
defaultStatementName :: StatementName
defaultStatementName = StatementName ""
statementPrefix :: ByteString
statementPrefix = "_pw_statement_"

View File

@ -1,5 +1,3 @@
{-# language LambdaCase #-}
module Database.PostgreSQL.Protocol.Codecs.Numeric
( scientificToNumeric
, numericToScientific