mirror of
https://github.com/postgres-haskell/postgres-wire.git
synced 2024-11-26 09:33:46 +03:00
Simple session works
This commit is contained in:
parent
7f994e083b
commit
fb0de46ab6
@ -20,6 +20,7 @@ library
|
||||
, Database.PostgreSQL.Settings
|
||||
, Database.PostgreSQL.StatementStorage
|
||||
, Database.PostgreSQL.Types
|
||||
, Database.PostgreSQL.Session
|
||||
|
||||
, Database.PostgreSQL.Protocol.Types
|
||||
, Database.PostgreSQL.Protocol.Encoders
|
||||
@ -35,6 +36,7 @@ library
|
||||
, hashtables
|
||||
, unagi-chan
|
||||
, unordered-containers
|
||||
, postgresql-binary
|
||||
default-language: Haskell2010
|
||||
default-extensions:
|
||||
OverloadedStrings
|
||||
|
@ -217,52 +217,23 @@ data Query = Query
|
||||
, qResultFormat :: Format
|
||||
} deriving (Show)
|
||||
|
||||
query1 = Query "SELECT $1 + $2" [Oid 23, Oid 23] ["1", "3"] Text Text
|
||||
query2 = Query "SELECT $1 + $2" [Oid 23, Oid 23] ["a", "3"] Text Text
|
||||
query3 = Query "SELECT $1 + $2" [Oid 23, Oid 23] ["3", "3"] Text Text
|
||||
query4 = Query "SELECT $1 + $2" [Oid 23, Oid 23] ["4", "3"] Text Text
|
||||
|
||||
sendBatch :: Connection -> [Query] -> IO ()
|
||||
sendBatch conn qs = do
|
||||
traverse sendSingle $ take 5 qs
|
||||
sendMessage s Sync
|
||||
sendBatch conn = traverse_ sendSingle
|
||||
where
|
||||
s = connSocket conn
|
||||
sname = StatementName ""
|
||||
pname = PortalName ""
|
||||
sendSingle q = do
|
||||
let sname = StatementName ""
|
||||
pname = PortalName ""
|
||||
sendMessage s $ Parse sname (StatementSQL $ qStatement q) (qOids q)
|
||||
sendMessage s $
|
||||
Bind pname sname (qParamsFormat q) (qValues q) (qResultFormat q)
|
||||
sendMessage s $ Execute pname noLimitToReceive
|
||||
|
||||
sendSync :: Connection -> IO ()
|
||||
sendSync conn = sendMessage (connSocket conn) Sync
|
||||
|
||||
test :: IO ()
|
||||
test = do
|
||||
c <- connect defaultConnectionSettings
|
||||
sendBatch c queries
|
||||
readResults c $ length queries
|
||||
readReadyForQuery c >>= print
|
||||
close c
|
||||
where
|
||||
queries = [query1, query2, query3, query4 ]
|
||||
readResults c 0 = pure ()
|
||||
readResults c n = do
|
||||
r <- readNextData c
|
||||
print r
|
||||
case r of
|
||||
Left _ -> pure ()
|
||||
Right _ -> readResults c $ n - 1
|
||||
|
||||
-- sendBatchAndSync :: IsQuery a => [a] -> Connection -> IO ()
|
||||
-- sendBatchAndSync = undefined
|
||||
|
||||
-- sendBatchAndFlush :: IsQuery a => [a] -> Connection -> IO ()
|
||||
-- sendBatchAndFlush = undefined
|
||||
|
||||
-- internal helper
|
||||
-- sendBatch :: IsQuery a => [a] -> Connection -> IO ()
|
||||
-- sendBatch = undefined
|
||||
sendFlush :: Connection -> IO ()
|
||||
sendFlush conn = sendMessage (connSocket conn) Flush
|
||||
|
||||
readNextData :: Connection -> IO (Either Error DataMessage)
|
||||
readNextData conn = readChan $ connOutDataChan conn
|
||||
@ -308,6 +279,40 @@ describeStatement conn stmt = do
|
||||
xs -> maybe (error "Impossible happened") (Left . PostgresError )
|
||||
$ findFirstError xs
|
||||
|
||||
query1 = Query "SELECT $1 + $2" [Oid 23, Oid 23] ["1", "3"] Text Text
|
||||
query2 = Query "SELECT $1 + $2" [Oid 23, Oid 23] ["a", "3"] Text Text
|
||||
query3 = Query "SELECT $1 + $2" [Oid 23, Oid 23] ["3", "3"] Text Text
|
||||
query4 = Query "SELECT $1 + $2" [Oid 23, Oid 23] ["4", "3"] Text Text
|
||||
|
||||
|
||||
test :: IO ()
|
||||
test = do
|
||||
c <- connect defaultConnectionSettings
|
||||
sendBatch c queries
|
||||
readResults c $ length queries
|
||||
readReadyForQuery c >>= print
|
||||
close c
|
||||
where
|
||||
queries = [query1, query2, query3, query4 ]
|
||||
readResults c 0 = pure ()
|
||||
readResults c n = do
|
||||
r <- readNextData c
|
||||
print r
|
||||
case r of
|
||||
Left _ -> pure ()
|
||||
Right _ -> readResults c $ n - 1
|
||||
|
||||
-- sendBatchAndSync :: IsQuery a => [a] -> Connection -> IO ()
|
||||
-- sendBatchAndSync = undefined
|
||||
|
||||
-- sendBatchAndFlush :: IsQuery a => [a] -> Connection -> IO ()
|
||||
-- sendBatchAndFlush = undefined
|
||||
|
||||
-- internal helper
|
||||
-- sendBatch :: IsQuery a => [a] -> Connection -> IO ()
|
||||
-- sendBatch = undefined
|
||||
|
||||
|
||||
testDescribe1 :: IO ()
|
||||
testDescribe1 = do
|
||||
c <- connect defaultConnectionSettings
|
||||
|
@ -1,10 +1,33 @@
|
||||
{-# language ApplicativeDo #-}
|
||||
{-# language OverloadedLists #-}
|
||||
{-# language OverloadedStrings #-}
|
||||
{-# language ExistentialQuantification #-}
|
||||
{-# language TypeSynonymInstances #-}
|
||||
{-# language FlexibleInstances #-}
|
||||
module Database.PostgreSQL.Session where
|
||||
|
||||
import Control.Monad
|
||||
import Control.Applicative
|
||||
import Data.Monoid
|
||||
import Data.Int
|
||||
import Data.ByteString (ByteString)
|
||||
import Data.Either
|
||||
import qualified Data.Vector as V
|
||||
|
||||
import PostgreSQL.Binary.Encoder (int8_int64, run)
|
||||
import qualified PostgreSQL.Binary.Decoder as D(int, run)
|
||||
|
||||
import Database.PostgreSQL.Protocol.Types
|
||||
import Database.PostgreSQL.Connection
|
||||
import Database.PostgreSQL.Settings
|
||||
|
||||
data Count = One | Many
|
||||
deriving (Eq, Show)
|
||||
|
||||
data Session a
|
||||
= Done a
|
||||
| forall r . Decode r => Receive (r -> Session a)
|
||||
| Send Count [Request] (Session a)
|
||||
| Send Count [Query] (Session a)
|
||||
|
||||
instance Functor Session where
|
||||
f `fmap` (Done a) = Done $ f a
|
||||
@ -43,21 +66,70 @@ instance Monad Session where
|
||||
|
||||
(>>) = (*>)
|
||||
|
||||
runSession :: Show a => Connection -> Session a -> IO a
|
||||
runSession conn@(Connection sock _ chan) = go
|
||||
where
|
||||
go (Done x) = do
|
||||
putStrLn $ "Return " ++ show x
|
||||
pure x
|
||||
go (Receive f) = do
|
||||
putStrLn "Receiving"
|
||||
-- TODO receive here
|
||||
-- x <- receive
|
||||
x <- getLine
|
||||
go (f $ decode x)
|
||||
go (Send _ rs c) = do
|
||||
putStrLn "Sending requests "
|
||||
-- TODO send requests here in batch
|
||||
sendBatch conn rs
|
||||
go c
|
||||
class Encode a where
|
||||
encode :: a -> ByteString
|
||||
getOid :: a -> Oid
|
||||
|
||||
class Decode a where
|
||||
decode :: ByteString -> a
|
||||
|
||||
instance Encode Int64 where
|
||||
encode = run int8_int64
|
||||
getOid _ = Oid 20
|
||||
|
||||
instance Decode Int64 where
|
||||
decode = fromRight . D.run D.int
|
||||
where
|
||||
fromRight (Right v) = v
|
||||
fromRight _ = error "bad fromRight"
|
||||
|
||||
data SessionQuery a b = SessionQuery { sqStatement :: ByteString }
|
||||
deriving (Show)
|
||||
|
||||
query :: (Encode a, Decode b) => SessionQuery a b -> a -> Session b
|
||||
query sq val =
|
||||
let q = Query { qStatement = sqStatement sq
|
||||
, qOids = [getOid val]
|
||||
, qValues = [encode val]
|
||||
, qParamsFormat = Binary
|
||||
, qResultFormat = Binary }
|
||||
in Send One [q] $ Receive Done
|
||||
|
||||
runSession :: Show a => Connection -> Session a -> IO (Either Error a)
|
||||
runSession conn = go 0
|
||||
where
|
||||
go n (Done x) = do
|
||||
putStrLn $ "Return " ++ show x
|
||||
when (n > 0) $ void $ sendSync conn >> readReadyForQuery conn
|
||||
pure $ Right x
|
||||
go n (Receive f) = do
|
||||
putStrLn "Receiving"
|
||||
r <- readNextData conn
|
||||
case r of
|
||||
Left e -> pure $ Left e
|
||||
Right (DataMessage rows) -> go n (f $ decode $ V.head $ head rows)
|
||||
go n (Send _ qs c) = do
|
||||
putStrLn "Sending requests "
|
||||
sendBatch conn qs
|
||||
sendFlush conn
|
||||
go (n + 1) c
|
||||
|
||||
q1 :: SessionQuery Int64 Int64
|
||||
q1 = SessionQuery "SELECT $1"
|
||||
|
||||
q2 :: SessionQuery Int64 Int64
|
||||
q2 = SessionQuery "SELECT count(*) from a where v < $1"
|
||||
|
||||
q3 :: SessionQuery Int64 Int64
|
||||
q3 = SessionQuery "SELECT 5 + $1"
|
||||
|
||||
testSession :: IO ()
|
||||
testSession = do
|
||||
c <- connect defaultConnectionSettings
|
||||
r <- runSession c $ do
|
||||
b <- query q1 10
|
||||
a <- query q2 b
|
||||
query q3 a
|
||||
print r
|
||||
close c
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user