diff --git a/postgres-wire.cabal b/postgres-wire.cabal index 94d4480..0c89fa5 100644 --- a/postgres-wire.cabal +++ b/postgres-wire.cabal @@ -65,9 +65,11 @@ test-suite postgres-wire-test hs-source-dirs: tests main-is: test.hs other-modules: Connection + , Driver , Protocol build-depends: base , postgres-wire + , bytestring , tasty , tasty-hunit ghc-options: -threaded -rtsopts -with-rtsopts=-N diff --git a/src/Database/PostgreSQL/Driver/Connection.hs b/src/Database/PostgreSQL/Driver/Connection.hs index fd29660..3d12ecd 100644 --- a/src/Database/PostgreSQL/Driver/Connection.hs +++ b/src/Database/PostgreSQL/Driver/Connection.hs @@ -71,7 +71,7 @@ data AuthError deriving (Show) data DataMessage = DataMessage [V.Vector B.ByteString] - deriving (Show) + deriving (Show, Eq) -- | Abstraction over raw socket connection or tls connection data RawConnection = RawConnection @@ -122,6 +122,7 @@ constructRawConnection s = RawConnection , rReceive = \n -> Socket.receive s n mempty } +-- | Public connect :: ConnectionSettings -> IO Connection connect settings = connectWith settings defaultFilter @@ -206,26 +207,12 @@ authorize rawConn settings = do handshakeTls :: RawConnection -> IO () handshakeTls _ = pure () +-- | Public close :: Connection -> IO () close conn = do killThread $ connReceiverThread conn rClose $ connRawConnection conn -consStartupMessage :: ConnectionSettings -> StartMessage -consStartupMessage stg = StartupMessage - (Username $ settingsUser stg) (DatabaseName $ settingsDatabase stg) - -sendStartMessage :: RawConnection -> StartMessage -> IO () -sendStartMessage rawConn msg = void $ do - let smsg = toStrict . toLazyByteString $ encodeStartMessage msg - rSend rawConn smsg - -sendMessage :: RawConnection -> ClientMessage -> IO () -sendMessage rawConn msg = void $ do - let smsg = toStrict . toLazyByteString $ encodeClientMessage msg - rSend rawConn smsg - - receiverThread :: ServerMessageFilter -> RawConnection @@ -320,6 +307,22 @@ defaultFilter msg = case msg of -- as result for `describe` message RowDescription{} -> True +consStartupMessage :: ConnectionSettings -> StartMessage +consStartupMessage stg = StartupMessage + (Username $ settingsUser stg) (DatabaseName $ settingsDatabase stg) + +sendStartMessage :: RawConnection -> StartMessage -> IO () +sendStartMessage rawConn msg = void $ do + let smsg = toStrict . toLazyByteString $ encodeStartMessage msg + rSend rawConn smsg + +sendMessage :: RawConnection -> ClientMessage -> IO () +sendMessage rawConn msg = void $ do + let smsg = toStrict . toLazyByteString $ encodeClientMessage msg + rSend rawConn smsg + + +-- Public data Query = Query { qStatement :: B.ByteString , qOids :: V.Vector Oid @@ -349,11 +352,9 @@ sendBatchAndSync conn qs = sendBatch conn qs >> sendSync conn sendBatchAndFlush :: Connection -> [Query] -> IO () sendBatchAndFlush conn qs = sendBatch conn qs >> sendFlush conn --- | Public sendSync :: Connection -> IO () sendSync conn = sendMessage (connRawConnection conn) Sync --- | Public sendFlush :: Connection -> IO () sendFlush conn = sendMessage (connRawConnection conn) Flush diff --git a/tests/Driver.hs b/tests/Driver.hs new file mode 100644 index 0000000..6a26c25 --- /dev/null +++ b/tests/Driver.hs @@ -0,0 +1,60 @@ +module Driver where + +import Data.Monoid ((<>)) +import Data.Foldable +import Control.Monad +import qualified Data.ByteString as B +import qualified Data.ByteString.Char8 as BS + +import Test.Tasty +import Test.Tasty.HUnit + +import Database.PostgreSQL.Driver.Connection +import Database.PostgreSQL.Protocol.Types + +import Connection + +makeQuery1 :: B.ByteString -> Query +makeQuery1 n = Query "SELECT $1" [Oid 23] [n] Text Text + +makeQuery2 :: B.ByteString -> B.ByteString -> Query +makeQuery2 n1 n2 = Query "SELECT $1 + $2" [Oid 23, Oid 23] [n1, n2] Text Text + +testDriver = testGroup "Driver" + [ testCase "Single batch" testBatch + , testCase "Two batches" testTwoBatches + ] + +fromRight (Right v) = v +fromRight _ = error "fromRight" + +testBatch :: IO () +testBatch = withConnection $ \c -> do + let a = "5" + b = "3" + sendBatchAndSync c [makeQuery1 a, makeQuery1 b] + readReadyForQuery c + + r1 <- readNextData c + r2 <- readNextData c + DataMessage [[a]] @=? fromRight r1 + DataMessage [[b]] @=? fromRight r2 + +testTwoBatches :: IO () +testTwoBatches = withConnection $ \c -> do + let a = 7 + b = 2 + sendBatchAndFlush c [ makeQuery1 (BS.pack (show a)) + , makeQuery1 (BS.pack (show b))] + r1 <- fromMessage . fromRight <$> readNextData c + r2 <- fromMessage . fromRight <$> readNextData c + + sendBatchAndSync c [makeQuery2 r1 r2] + r <- readNextData c + readReadyForQuery c + + DataMessage [[BS.pack (show $ a + b)]] @=? fromRight r + where + fromMessage (DataMessage [[v]]) = v + fromMessage _ = error "from message" + diff --git a/tests/test.hs b/tests/test.hs index 26678cc..b0ba778 100644 --- a/tests/test.hs +++ b/tests/test.hs @@ -1,9 +1,11 @@ import Test.Tasty (defaultMain, testGroup) import Protocol +import Driver main :: IO () main = defaultMain $ testGroup "Postgres-wire" [ testProtocolMessages + , testDriver ]