diff --git a/postgres-wire.cabal b/postgres-wire.cabal index 7a8ff74..1b3945b 100644 --- a/postgres-wire.cabal +++ b/postgres-wire.cabal @@ -58,6 +58,7 @@ library BangPatterns OverloadedStrings GeneralizedNewtypeDeriving + LambdaCase cc-options: -O2 -Wall test-suite postgres-wire-test-connection diff --git a/src/Database/PostgreSQL/Driver/Connection.hs b/src/Database/PostgreSQL/Driver/Connection.hs index b334acd..4163cc4 100644 --- a/src/Database/PostgreSQL/Driver/Connection.hs +++ b/src/Database/PostgreSQL/Driver/Connection.hs @@ -56,11 +56,11 @@ import Database.PostgreSQL.Driver.RawConnection -- | Public -- Connection parametrized by message type in chan. data AbsConnection mt = AbsConnection - { connRawConnection :: RawConnection - , connReceiverThread :: Weak ThreadId - , connStatementStorage :: StatementStorage - , connParameters :: ConnectionParameters - , connOutChan :: TQueue (Either ReceiverException mt) + { connRawConnection :: !RawConnection + , connReceiverThread :: !(Weak ThreadId) + , connStatementStorage :: !StatementStorage + , connParameters :: !ConnectionParameters + , connOutChan :: !(TQueue (Either ReceiverException mt)) } type Connection = AbsConnection DataMessage @@ -122,15 +122,18 @@ connectCommon' settings msgFilter = connectWith settings $ \rawConn params -> -- Low-level sending functions +{-# INLINE sendStartMessage #-} sendStartMessage :: RawConnection -> StartMessage -> IO () sendStartMessage rawConn msg = void $ rSend rawConn . runEncode $ encodeStartMessage msg -- Only for testings and simple queries +{-# INLINE sendMessage #-} sendMessage :: RawConnection -> ClientMessage -> IO () sendMessage rawConn msg = void $ rSend rawConn . runEncode $ encodeClientMessage msg +{-# INLINE sendEncode #-} sendEncode :: AbsConnection c -> Encode -> IO () sendEncode conn = void . rSend (connRawConnection conn) . runEncode @@ -290,6 +293,11 @@ receiverThreadCommon rawConn chan msgFilter ntfHandler = go "" dispatchIfNotification (NotificationResponse ntf) handler = handler ntf dispatchIfNotification _ _ = pure () +-- | Helper to read from queue. +{-# INLINE writeChan #-} +writeChan :: TQueue a -> a -> IO () +writeChan q = atomically . writeTQueue q + defaultNotificationHandler :: NotificationHandler defaultNotificationHandler = const $ pure () @@ -332,7 +340,3 @@ defaultFilter msg = case msg of -- as result for `describe` message RowDescription{} -> True --- | Helper to read from queue. -writeChan :: TQueue a -> a -> IO () -writeChan q = atomically . writeTQueue q - diff --git a/src/Database/PostgreSQL/Driver/Query.hs b/src/Database/PostgreSQL/Driver/Query.hs index 0c9610e..fc1a934 100644 --- a/src/Database/PostgreSQL/Driver/Query.hs +++ b/src/Database/PostgreSQL/Driver/Query.hs @@ -12,13 +12,12 @@ module Database.PostgreSQL.Driver.Query , collectUntilReadyForQuery ) where -import Data.Foldable -import Data.Monoid -import Data.Bifunctor -import qualified Data.Vector as V -import qualified Data.ByteString as B import Control.Concurrent.STM.TQueue (TQueue, readTQueue ) -import Control.Concurrent.STM (atomically) +import Control.Concurrent.STM (atomically) +import Data.Foldable (fold) +import Data.Monoid ((<>)) +import Data.ByteString (ByteString) +import Data.Vector (Vector) import Database.PostgreSQL.Protocol.Encoders import Database.PostgreSQL.Protocol.Store.Encode @@ -31,26 +30,30 @@ import Database.PostgreSQL.Driver.StatementStorage -- Public data Query = Query - { qStatement :: B.ByteString - , qValues :: [(Oid, Maybe Encode)] - , qParamsFormat :: Format - , qResultFormat :: Format - , qCachePolicy :: CachePolicy + { qStatement :: !ByteString + , qValues :: ![(Oid, Maybe Encode)] + , qParamsFormat :: !Format + , qResultFormat :: !Format + , qCachePolicy :: !CachePolicy } deriving (Show) -- | Public +{- INLINE sendBatchAndFlush #-} sendBatchAndFlush :: Connection -> [Query] -> IO () sendBatchAndFlush = sendBatchEndBy Flush -- | Public +{-# INLINE sendBatchAndSync #-} sendBatchAndSync :: Connection -> [Query] -> IO () sendBatchAndSync = sendBatchEndBy Sync -- | Public +{-# INLINE sendSync #-} sendSync :: Connection -> IO () sendSync conn = sendEncode conn $ encodeClientMessage Sync -- | Public +{-# INLINABLE readNextData #-} readNextData :: Connection -> IO (Either Error DataRows) readNextData conn = readChan (connOutChan conn) >>= @@ -62,6 +65,7 @@ readNextData conn = DataReady -> throwIncorrectUsage "Expected DataRow message, but got ReadyForQuery" +{-# INLINABLE waitReadyForQuery #-} waitReadyForQuery :: Connection -> IO (Either Error ()) waitReadyForQuery conn = readChan (connOutChan conn) >>= @@ -77,6 +81,7 @@ waitReadyForQuery conn = DataReady -> pure $ Right () -- Helper +{-# INLINE sendBatchEndBy #-} sendBatchEndBy :: ClientMessage -> Connection -> [Query] -> IO () sendBatchEndBy msg conn qs = do batch <- constructBatch conn qs @@ -90,28 +95,27 @@ constructBatch conn = fmap fold . traverse constructSingle pname = PortalName "" constructSingle q = do let stmtSQL = StatementSQL $ qStatement q - (sname, parseMessage) <- case qCachePolicy q of - AlwaysCache -> do - mName <- lookupStatement storage stmtSQL - case mName of - Nothing -> do - newName <- storeStatement storage stmtSQL - pure (newName, encodeClientMessage $ - Parse newName stmtSQL (fst <$> qValues q)) - Just name -> pure (name, mempty) - NeverCache -> do - let newName = defaultStatementName - pure (newName, encodeClientMessage $ - Parse newName stmtSQL (fst <$> qValues q)) - let bindMessage = encodeClientMessage $ - Bind pname sname (qParamsFormat q) (snd <$> qValues q) + (stmtName, needParse) <- case qCachePolicy q of + AlwaysCache -> lookupStatement storage stmtSQL >>= \case + Nothing -> do + newName <- storeStatement storage stmtSQL + pure (newName, True) + Just name -> + pure (name, False) + NeverCache -> pure (defaultStatementName, True) + let parseMessage = if needParse + then encodeClientMessage $ + Parse stmtName stmtSQL (fst <$> qValues q) + else mempty + bindMessage = encodeClientMessage $ + Bind pname stmtName (qParamsFormat q) (snd <$> qValues q) (qResultFormat q) executeMessage = encodeClientMessage $ Execute pname noLimitToReceive pure $ parseMessage <> bindMessage <> executeMessage -- | Public -sendSimpleQuery :: ConnectionCommon -> B.ByteString -> IO (Either Error ()) +sendSimpleQuery :: ConnectionCommon -> ByteString -> IO (Either Error ()) sendSimpleQuery conn q = do sendMessage (connRawConnection conn) $ SimpleQuery (StatementSQL q) (checkErrors =<<) <$> collectUntilReadyForQuery conn @@ -122,8 +126,8 @@ sendSimpleQuery conn q = do -- | Public describeStatement :: ConnectionCommon - -> B.ByteString - -> IO (Either Error (V.Vector Oid, V.Vector FieldDescription)) + -> ByteString + -> IO (Either Error (Vector Oid, Vector FieldDescription)) describeStatement conn stmt = do sendEncode conn $ encodeClientMessage (Parse sname (StatementSQL stmt) []) @@ -135,7 +139,7 @@ describeStatement conn stmt = do sname = StatementName "" parseMessages msgs = case msgs of [ParameterDescription params, NoData] - -> pure $ Right (params, V.empty) + -> pure $ Right (params, mempty) [ParameterDescription params, RowDescription fields] -> pure $ Right (params, fields) xs -> maybe @@ -160,5 +164,6 @@ findFirstError [] = Nothing findFirstError (ErrorResponse desc : _) = Just desc findFirstError (_ : xs) = findFirstError xs +{-# INLINE readChan #-} readChan :: TQueue a -> IO a readChan = atomically . readTQueue diff --git a/src/Database/PostgreSQL/Driver/StatementStorage.hs b/src/Database/PostgreSQL/Driver/StatementStorage.hs index 079034e..f24a5aa 100644 --- a/src/Database/PostgreSQL/Driver/StatementStorage.hs +++ b/src/Database/PostgreSQL/Driver/StatementStorage.hs @@ -1,10 +1,20 @@ -module Database.PostgreSQL.Driver.StatementStorage where +module Database.PostgreSQL.Driver.StatementStorage + ( StatementStorage + , CachePolicy(..) + , newStatementStorage + , lookupStatement + , storeStatement + , getCacheSize + , defaultStatementName + ) where -import qualified Data.HashTable.IO as H -import qualified Data.ByteString as B +import Data.Monoid ((<>)) +import Data.IORef (IORef, newIORef, readIORef, writeIORef) +import Data.Word (Word) + +import Data.ByteString (ByteString) import Data.ByteString.Char8 (pack) -import Data.Word (Word) -import Data.IORef +import qualified Data.HashTable.IO as H import Database.PostgreSQL.Protocol.Types @@ -21,16 +31,17 @@ data CachePolicy newStatementStorage :: IO StatementStorage newStatementStorage = StatementStorage <$> H.new <*> newIORef 0 +{-# INLINE lookupStatement #-} lookupStatement :: StatementStorage -> StatementSQL -> IO (Maybe StatementName) lookupStatement (StatementStorage table _) = H.lookup table --- TODO place right name -- TODO info about exceptions and mask +{-# INLINE storeStatement #-} storeStatement :: StatementStorage -> StatementSQL -> IO StatementName storeStatement (StatementStorage table counter) stmt = do n <- readIORef counter writeIORef counter $ n + 1 - let name = StatementName . pack $ show n + let name = StatementName . (statementPrefix <>) . pack $ show n H.insert table stmt name pure name @@ -40,3 +51,6 @@ getCacheSize (StatementStorage _ counter) = readIORef counter defaultStatementName :: StatementName defaultStatementName = StatementName "" +statementPrefix :: ByteString +statementPrefix = "_pw_statement_" + diff --git a/src/Database/PostgreSQL/Protocol/Codecs/Numeric.hs b/src/Database/PostgreSQL/Protocol/Codecs/Numeric.hs index d2930b5..3a201a1 100644 --- a/src/Database/PostgreSQL/Protocol/Codecs/Numeric.hs +++ b/src/Database/PostgreSQL/Protocol/Codecs/Numeric.hs @@ -1,5 +1,3 @@ -{-# language LambdaCase #-} - module Database.PostgreSQL.Protocol.Codecs.Numeric ( scientificToNumeric , numericToScientific