mirror of
https://github.com/postgres-haskell/postgres-wire.git
synced 2024-11-21 20:20:16 +03:00
Refactor Driver directory
This commit is contained in:
parent
1b89885392
commit
856616005e
@ -58,6 +58,7 @@ library
|
||||
BangPatterns
|
||||
OverloadedStrings
|
||||
GeneralizedNewtypeDeriving
|
||||
LambdaCase
|
||||
cc-options: -O2 -Wall
|
||||
|
||||
test-suite postgres-wire-test-connection
|
||||
|
@ -56,11 +56,11 @@ import Database.PostgreSQL.Driver.RawConnection
|
||||
-- | Public
|
||||
-- Connection parametrized by message type in chan.
|
||||
data AbsConnection mt = AbsConnection
|
||||
{ connRawConnection :: RawConnection
|
||||
, connReceiverThread :: Weak ThreadId
|
||||
, connStatementStorage :: StatementStorage
|
||||
, connParameters :: ConnectionParameters
|
||||
, connOutChan :: TQueue (Either ReceiverException mt)
|
||||
{ connRawConnection :: !RawConnection
|
||||
, connReceiverThread :: !(Weak ThreadId)
|
||||
, connStatementStorage :: !StatementStorage
|
||||
, connParameters :: !ConnectionParameters
|
||||
, connOutChan :: !(TQueue (Either ReceiverException mt))
|
||||
}
|
||||
|
||||
type Connection = AbsConnection DataMessage
|
||||
@ -122,15 +122,18 @@ connectCommon' settings msgFilter = connectWith settings $ \rawConn params ->
|
||||
|
||||
-- Low-level sending functions
|
||||
|
||||
{-# INLINE sendStartMessage #-}
|
||||
sendStartMessage :: RawConnection -> StartMessage -> IO ()
|
||||
sendStartMessage rawConn msg = void $
|
||||
rSend rawConn . runEncode $ encodeStartMessage msg
|
||||
|
||||
-- Only for testings and simple queries
|
||||
{-# INLINE sendMessage #-}
|
||||
sendMessage :: RawConnection -> ClientMessage -> IO ()
|
||||
sendMessage rawConn msg = void $
|
||||
rSend rawConn . runEncode $ encodeClientMessage msg
|
||||
|
||||
{-# INLINE sendEncode #-}
|
||||
sendEncode :: AbsConnection c -> Encode -> IO ()
|
||||
sendEncode conn = void . rSend (connRawConnection conn) . runEncode
|
||||
|
||||
@ -290,6 +293,11 @@ receiverThreadCommon rawConn chan msgFilter ntfHandler = go ""
|
||||
dispatchIfNotification (NotificationResponse ntf) handler = handler ntf
|
||||
dispatchIfNotification _ _ = pure ()
|
||||
|
||||
-- | Helper to read from queue.
|
||||
{-# INLINE writeChan #-}
|
||||
writeChan :: TQueue a -> a -> IO ()
|
||||
writeChan q = atomically . writeTQueue q
|
||||
|
||||
defaultNotificationHandler :: NotificationHandler
|
||||
defaultNotificationHandler = const $ pure ()
|
||||
|
||||
@ -332,7 +340,3 @@ defaultFilter msg = case msg of
|
||||
-- as result for `describe` message
|
||||
RowDescription{} -> True
|
||||
|
||||
-- | Helper to read from queue.
|
||||
writeChan :: TQueue a -> a -> IO ()
|
||||
writeChan q = atomically . writeTQueue q
|
||||
|
||||
|
@ -12,13 +12,12 @@ module Database.PostgreSQL.Driver.Query
|
||||
, collectUntilReadyForQuery
|
||||
) where
|
||||
|
||||
import Data.Foldable
|
||||
import Data.Monoid
|
||||
import Data.Bifunctor
|
||||
import qualified Data.Vector as V
|
||||
import qualified Data.ByteString as B
|
||||
import Control.Concurrent.STM.TQueue (TQueue, readTQueue )
|
||||
import Control.Concurrent.STM (atomically)
|
||||
import Control.Concurrent.STM (atomically)
|
||||
import Data.Foldable (fold)
|
||||
import Data.Monoid ((<>))
|
||||
import Data.ByteString (ByteString)
|
||||
import Data.Vector (Vector)
|
||||
|
||||
import Database.PostgreSQL.Protocol.Encoders
|
||||
import Database.PostgreSQL.Protocol.Store.Encode
|
||||
@ -31,26 +30,30 @@ import Database.PostgreSQL.Driver.StatementStorage
|
||||
|
||||
-- Public
|
||||
data Query = Query
|
||||
{ qStatement :: B.ByteString
|
||||
, qValues :: [(Oid, Maybe Encode)]
|
||||
, qParamsFormat :: Format
|
||||
, qResultFormat :: Format
|
||||
, qCachePolicy :: CachePolicy
|
||||
{ qStatement :: !ByteString
|
||||
, qValues :: ![(Oid, Maybe Encode)]
|
||||
, qParamsFormat :: !Format
|
||||
, qResultFormat :: !Format
|
||||
, qCachePolicy :: !CachePolicy
|
||||
} deriving (Show)
|
||||
|
||||
-- | Public
|
||||
{- INLINE sendBatchAndFlush #-}
|
||||
sendBatchAndFlush :: Connection -> [Query] -> IO ()
|
||||
sendBatchAndFlush = sendBatchEndBy Flush
|
||||
|
||||
-- | Public
|
||||
{-# INLINE sendBatchAndSync #-}
|
||||
sendBatchAndSync :: Connection -> [Query] -> IO ()
|
||||
sendBatchAndSync = sendBatchEndBy Sync
|
||||
|
||||
-- | Public
|
||||
{-# INLINE sendSync #-}
|
||||
sendSync :: Connection -> IO ()
|
||||
sendSync conn = sendEncode conn $ encodeClientMessage Sync
|
||||
|
||||
-- | Public
|
||||
{-# INLINABLE readNextData #-}
|
||||
readNextData :: Connection -> IO (Either Error DataRows)
|
||||
readNextData conn =
|
||||
readChan (connOutChan conn) >>=
|
||||
@ -62,6 +65,7 @@ readNextData conn =
|
||||
DataReady -> throwIncorrectUsage
|
||||
"Expected DataRow message, but got ReadyForQuery"
|
||||
|
||||
{-# INLINABLE waitReadyForQuery #-}
|
||||
waitReadyForQuery :: Connection -> IO (Either Error ())
|
||||
waitReadyForQuery conn =
|
||||
readChan (connOutChan conn) >>=
|
||||
@ -77,6 +81,7 @@ waitReadyForQuery conn =
|
||||
DataReady -> pure $ Right ()
|
||||
|
||||
-- Helper
|
||||
{-# INLINE sendBatchEndBy #-}
|
||||
sendBatchEndBy :: ClientMessage -> Connection -> [Query] -> IO ()
|
||||
sendBatchEndBy msg conn qs = do
|
||||
batch <- constructBatch conn qs
|
||||
@ -90,28 +95,27 @@ constructBatch conn = fmap fold . traverse constructSingle
|
||||
pname = PortalName ""
|
||||
constructSingle q = do
|
||||
let stmtSQL = StatementSQL $ qStatement q
|
||||
(sname, parseMessage) <- case qCachePolicy q of
|
||||
AlwaysCache -> do
|
||||
mName <- lookupStatement storage stmtSQL
|
||||
case mName of
|
||||
Nothing -> do
|
||||
newName <- storeStatement storage stmtSQL
|
||||
pure (newName, encodeClientMessage $
|
||||
Parse newName stmtSQL (fst <$> qValues q))
|
||||
Just name -> pure (name, mempty)
|
||||
NeverCache -> do
|
||||
let newName = defaultStatementName
|
||||
pure (newName, encodeClientMessage $
|
||||
Parse newName stmtSQL (fst <$> qValues q))
|
||||
let bindMessage = encodeClientMessage $
|
||||
Bind pname sname (qParamsFormat q) (snd <$> qValues q)
|
||||
(stmtName, needParse) <- case qCachePolicy q of
|
||||
AlwaysCache -> lookupStatement storage stmtSQL >>= \case
|
||||
Nothing -> do
|
||||
newName <- storeStatement storage stmtSQL
|
||||
pure (newName, True)
|
||||
Just name ->
|
||||
pure (name, False)
|
||||
NeverCache -> pure (defaultStatementName, True)
|
||||
let parseMessage = if needParse
|
||||
then encodeClientMessage $
|
||||
Parse stmtName stmtSQL (fst <$> qValues q)
|
||||
else mempty
|
||||
bindMessage = encodeClientMessage $
|
||||
Bind pname stmtName (qParamsFormat q) (snd <$> qValues q)
|
||||
(qResultFormat q)
|
||||
executeMessage = encodeClientMessage $
|
||||
Execute pname noLimitToReceive
|
||||
pure $ parseMessage <> bindMessage <> executeMessage
|
||||
|
||||
-- | Public
|
||||
sendSimpleQuery :: ConnectionCommon -> B.ByteString -> IO (Either Error ())
|
||||
sendSimpleQuery :: ConnectionCommon -> ByteString -> IO (Either Error ())
|
||||
sendSimpleQuery conn q = do
|
||||
sendMessage (connRawConnection conn) $ SimpleQuery (StatementSQL q)
|
||||
(checkErrors =<<) <$> collectUntilReadyForQuery conn
|
||||
@ -122,8 +126,8 @@ sendSimpleQuery conn q = do
|
||||
-- | Public
|
||||
describeStatement
|
||||
:: ConnectionCommon
|
||||
-> B.ByteString
|
||||
-> IO (Either Error (V.Vector Oid, V.Vector FieldDescription))
|
||||
-> ByteString
|
||||
-> IO (Either Error (Vector Oid, Vector FieldDescription))
|
||||
describeStatement conn stmt = do
|
||||
sendEncode conn $
|
||||
encodeClientMessage (Parse sname (StatementSQL stmt) [])
|
||||
@ -135,7 +139,7 @@ describeStatement conn stmt = do
|
||||
sname = StatementName ""
|
||||
parseMessages msgs = case msgs of
|
||||
[ParameterDescription params, NoData]
|
||||
-> pure $ Right (params, V.empty)
|
||||
-> pure $ Right (params, mempty)
|
||||
[ParameterDescription params, RowDescription fields]
|
||||
-> pure $ Right (params, fields)
|
||||
xs -> maybe
|
||||
@ -160,5 +164,6 @@ findFirstError [] = Nothing
|
||||
findFirstError (ErrorResponse desc : _) = Just desc
|
||||
findFirstError (_ : xs) = findFirstError xs
|
||||
|
||||
{-# INLINE readChan #-}
|
||||
readChan :: TQueue a -> IO a
|
||||
readChan = atomically . readTQueue
|
||||
|
@ -1,10 +1,20 @@
|
||||
module Database.PostgreSQL.Driver.StatementStorage where
|
||||
module Database.PostgreSQL.Driver.StatementStorage
|
||||
( StatementStorage
|
||||
, CachePolicy(..)
|
||||
, newStatementStorage
|
||||
, lookupStatement
|
||||
, storeStatement
|
||||
, getCacheSize
|
||||
, defaultStatementName
|
||||
) where
|
||||
|
||||
import qualified Data.HashTable.IO as H
|
||||
import qualified Data.ByteString as B
|
||||
import Data.Monoid ((<>))
|
||||
import Data.IORef (IORef, newIORef, readIORef, writeIORef)
|
||||
import Data.Word (Word)
|
||||
|
||||
import Data.ByteString (ByteString)
|
||||
import Data.ByteString.Char8 (pack)
|
||||
import Data.Word (Word)
|
||||
import Data.IORef
|
||||
import qualified Data.HashTable.IO as H
|
||||
|
||||
import Database.PostgreSQL.Protocol.Types
|
||||
|
||||
@ -21,16 +31,17 @@ data CachePolicy
|
||||
newStatementStorage :: IO StatementStorage
|
||||
newStatementStorage = StatementStorage <$> H.new <*> newIORef 0
|
||||
|
||||
{-# INLINE lookupStatement #-}
|
||||
lookupStatement :: StatementStorage -> StatementSQL -> IO (Maybe StatementName)
|
||||
lookupStatement (StatementStorage table _) = H.lookup table
|
||||
|
||||
-- TODO place right name
|
||||
-- TODO info about exceptions and mask
|
||||
{-# INLINE storeStatement #-}
|
||||
storeStatement :: StatementStorage -> StatementSQL -> IO StatementName
|
||||
storeStatement (StatementStorage table counter) stmt = do
|
||||
n <- readIORef counter
|
||||
writeIORef counter $ n + 1
|
||||
let name = StatementName . pack $ show n
|
||||
let name = StatementName . (statementPrefix <>) . pack $ show n
|
||||
H.insert table stmt name
|
||||
pure name
|
||||
|
||||
@ -40,3 +51,6 @@ getCacheSize (StatementStorage _ counter) = readIORef counter
|
||||
defaultStatementName :: StatementName
|
||||
defaultStatementName = StatementName ""
|
||||
|
||||
statementPrefix :: ByteString
|
||||
statementPrefix = "_pw_statement_"
|
||||
|
||||
|
@ -1,5 +1,3 @@
|
||||
{-# language LambdaCase #-}
|
||||
|
||||
module Database.PostgreSQL.Protocol.Codecs.Numeric
|
||||
( scientificToNumeric
|
||||
, numericToScientific
|
||||
|
Loading…
Reference in New Issue
Block a user