diff --git a/postgres-wire.cabal b/postgres-wire.cabal index 6a973fa..97af4b7 100644 --- a/postgres-wire.cabal +++ b/postgres-wire.cabal @@ -16,12 +16,14 @@ cabal-version: >=1.10 library hs-source-dirs: src exposed-modules: Database.PostgreSQL.Protocol + , Database.PostgreSQL.Connection + , Database.PostgreSQL.Settings + , Database.PostgreSQL.StatementStorage + , Database.PostgreSQL.Types + , Database.PostgreSQL.Protocol.Types - , Database.PostgreSQL.Protocol.Settings - , Database.PostgreSQL.Protocol.Connection , Database.PostgreSQL.Protocol.Encoders , Database.PostgreSQL.Protocol.Decoders - , Database.PostgreSQL.Protocol.StatementStorage build-depends: base >= 4.7 && < 5 , bytestring , socket diff --git a/src/Database/PostgreSQL/Protocol/Connection.hs b/src/Database/PostgreSQL/Connection.hs similarity index 69% rename from src/Database/PostgreSQL/Protocol/Connection.hs rename to src/Database/PostgreSQL/Connection.hs index 170143e..e59dc0e 100644 --- a/src/Database/PostgreSQL/Protocol/Connection.hs +++ b/src/Database/PostgreSQL/Connection.hs @@ -5,7 +5,7 @@ {-# language ExistentialQuantification #-} {-# language TypeSynonymInstances #-} {-# language FlexibleInstances #-} -module Database.PostgreSQL.Protocol.Connection where +module Database.PostgreSQL.Connection where import qualified Data.ByteString as B @@ -30,11 +30,12 @@ import System.Socket.Family.Unix import Data.Time.Clock.POSIX import Control.Concurrent.Chan.Unagi -import Database.PostgreSQL.Protocol.Settings import Database.PostgreSQL.Protocol.Encoders import Database.PostgreSQL.Protocol.Decoders import Database.PostgreSQL.Protocol.Types -import Database.PostgreSQL.Protocol.StatementStorage +import Database.PostgreSQL.Settings +import Database.PostgreSQL.StatementStorage +import Database.PostgreSQL.Types type UnixSocket = Socket Unix Stream Unix @@ -42,11 +43,25 @@ type UnixSocket = Socket Unix Stream Unix data Connection = Connection { connSocket :: UnixSocket , connReceiverThread :: ThreadId - , connOutChan :: OutChan ServerMessage + -- Chan for only data messages + , connDataOutChan :: OutChan (Either Error DataMessage) + -- Chan for all messages that filter + , connAllOutChan :: OutChan ServerMessage , connStatementStorage :: StatementStorage , connParameters :: ConnectionParameters } +newtype ServerMessageFilter = ServerMessageFilter (ServerMessage -> Bool) + +type NotificationHandler = Notification -> IO () + +-- All possible errors +data Error + = PostgresError ErrorDesc + | ImpossibleError + +data DataMessage = DataMessage B.ByteString + address :: SocketAddress Unix address = fromJust $ socketAddressUnixPath "/var/run/postgresql/.s.PGSQL.5432" @@ -85,15 +100,11 @@ consStartupMessage stg = StartupMessage sendStartMessage :: UnixSocket -> StartMessage -> IO () sendStartMessage sock msg = void $ do let smsg = toStrict . toLazyByteString $ encodeStartMessage msg - -- putStrLn "sending message:" - -- print smsg send sock smsg mempty sendMessage :: UnixSocket -> ClientMessage -> IO () sendMessage sock msg = void $ do let smsg = toStrict . toLazyByteString $ encodeClientMessage msg - -- putStrLn "sending message:" - -- print smsg send sock smsg mempty readAuthMessage :: B.ByteString -> IO () @@ -107,33 +118,54 @@ readAuthMessage s = receiverThread :: UnixSocket -> InChan ServerMessage -> IO () receiverThread sock chan = forever $ do r <- receive sock 4096 mempty + print r go r where decoder = runGetIncremental decodeServerMessage go str = case pushChunk decoder str of BG.Done rest _ v -> do - print v - writeChan chan v + putStrLn $ "Received: " ++ show v unless (B.null rest) $ go rest BG.Partial _ -> error "Partial" BG.Fail _ _ e -> error e + dispatch :: ServerMessage -> IO () + -- dont receiving at this phase + dispatch (BackendKeyData _ _) = pure () + dispatch (BindComplete) = pure () + dispatch CloseComplete = pure () + -- maybe return command result too + dispatch (CommandComplete _) = pure () + dispatch r@(DataRow _) = writeChan chan r + -- TODO throw error here + dispatch EmptyQueryResponse = pure () + -- TODO throw error here + dispatch (ErrorResponse desc) = pure () + -- TODO + dispatch NoData = pure () + dispatch (NoticeResponse _) = pure () + -- TODO handle notifications + dispatch (NotificationResponse n) = pure () + -- Ignore here ? + dispatch (ParameterDescription _) = pure () + dispatch (ParameterStatus _ _) = pure () + dispatch (ParseComplete) = pure () + dispatch (PortalSuspended) = pure () + dispatch (ReadForQuery _) = pure () + dispatch (RowDescription _) = pure () data Query = Query { qStatement :: B.ByteString , qOids :: V.Vector Oid , qValues :: V.Vector B.ByteString + , qParamsFormat :: Format + , qResultFormat :: Format } deriving (Show) -query1 = Query "SELECT $1 + $2" [Oid 23, Oid 23] ["1", "3"] -query2 = Query "SELECT $1 + $2" [Oid 23, Oid 23] ["2", "3"] -query3 = Query "SELECT $1 + $2" [Oid 23, Oid 23] ["3", "3"] -query4 = Query "SELECT $1 + $2" [Oid 23, Oid 23] ["4", "3"] -query5 = Query "SELECT * FROM a where v > $1 + $2 LIMIT 100" [Oid 23, Oid 23] ["5", "3"] --- query1 = QQuery "test1" "select sum(v) from a" [] [] --- query2 = QQuery "test2" "select sum(v) from a" [] [] --- query3 = QQuery "test3" "select sum(v) from a" [] [] --- query4 = QQuery "test4" "select sum(v) from a" [] [] --- query5 = QQuery "test5" "select sum(v) from a" [] [] +query1 = Query "SELECT $1 + $2" [Oid 23, Oid 23] ["1", "3"] Text Text +query2 = Query "SELECT $1 + $2" [Oid 23, Oid 23] ["2", "3"] Text Text +query3 = Query "SELECT $1 + $2" [Oid 23, Oid 23] ["3", "3"] Text Text +query4 = Query "SELECT $1 + $2" [Oid 23, Oid 23] ["4", "3"] Text Text +-- query5 = Query "SELECT * FROM a whereee v > $1 + $2 LIMIT 100" [Oid 23, Oid 23] ["5", "3"] sendBatch :: Connection -> [Query] -> IO () sendBatch conn qs = do @@ -145,7 +177,8 @@ sendBatch conn qs = do let sname = StatementName "" pname = PortalName "" sendMessage s $ Parse sname (StatementSQL $ qStatement q) (qOids q) - sendMessage s $ Bind pname sname Text (qValues q) Text + sendMessage s $ + Bind pname sname (qParamsFormat q) (qValues q) (qResultFormat q) sendMessage s $ Execute pname noLimitToReceive @@ -170,6 +203,9 @@ test = do -- readNextData :: Connection -> IO Data? -- readNextData = undefined -- +-- readNextServerMessage ? +-- +-- -- Simple Queries support or maybe dont support it -- because single text query may be send through extended protocol -- may be support for all standalone queries diff --git a/src/Database/PostgreSQL/Protocol/Types.hs b/src/Database/PostgreSQL/Protocol/Types.hs index 1167153..dedd51c 100644 --- a/src/Database/PostgreSQL/Protocol/Types.hs +++ b/src/Database/PostgreSQL/Protocol/Types.hs @@ -1,5 +1,14 @@ module Database.PostgreSQL.Protocol.Types where +-- TODO +-- * COPY subprotocol commands +-- +-- * function call, is deprecated by postgres +-- * AuthenticationKerberosV5 IS deprecated by postgres +-- * AuthenticationSCMCredential IS deprecated since postgres 9.1 +-- * bind command can have different formats for parameters and results +-- but we assume that there will be one format for all. + import Data.Word (Word32, Word8) import Data.Int (Int32, Int16) import Data.Hashable (Hashable) @@ -51,15 +60,6 @@ data CommandResult | CommandOk deriving (Show) --- | Parameters of the current connection. --- We store only the parameters that cannot change after startup. --- For more information about additional parameters see documentation. -data ConnectionParameters = ConnectionParameters - { paramServerVersion :: ServerVersion - , paramServerEncoding :: ByteString -- ^ character set name - , paramIntegerDatetimes :: Bool -- ^ True if integer datetimes used - } deriving (Show) - -- | Server version contains major, minor, revision numbers. data ServerVersion = ServerVersion Word8 Word8 Word8 @@ -226,12 +226,3 @@ data NoticeDesc = NoticeDesc , noticeSourceRoutine :: Maybe ByteString } deriving (Show) --- TODO --- * COPY subprotocol commands --- * function call, is deprecated by postgres --- * AuthenticationKerberosV5 IS deprecated by postgres --- * AuthenticationSCMCredential IS deprecated since postgres 9.1 --- * NOTICE bind command can have different formats for parameters and results --- but we assume that there will be one format for all. --- * We dont store parameters of connection that may change after startup - diff --git a/src/Database/PostgreSQL/Protocol/Session.hs b/src/Database/PostgreSQL/Session.hs similarity index 100% rename from src/Database/PostgreSQL/Protocol/Session.hs rename to src/Database/PostgreSQL/Session.hs diff --git a/src/Database/PostgreSQL/Protocol/Settings.hs b/src/Database/PostgreSQL/Settings.hs similarity index 92% rename from src/Database/PostgreSQL/Protocol/Settings.hs rename to src/Database/PostgreSQL/Settings.hs index 383d7ee..a50fe96 100644 --- a/src/Database/PostgreSQL/Protocol/Settings.hs +++ b/src/Database/PostgreSQL/Settings.hs @@ -1,6 +1,6 @@ {-# language OverloadedStrings #-} -module Database.PostgreSQL.Protocol.Settings where +module Database.PostgreSQL.Settings where import Data.Word (Word16) import Data.ByteString (ByteString) diff --git a/src/Database/PostgreSQL/Protocol/StatementStorage.hs b/src/Database/PostgreSQL/StatementStorage.hs similarity index 93% rename from src/Database/PostgreSQL/Protocol/StatementStorage.hs rename to src/Database/PostgreSQL/StatementStorage.hs index 6f83229..b27ce7c 100644 --- a/src/Database/PostgreSQL/Protocol/StatementStorage.hs +++ b/src/Database/PostgreSQL/StatementStorage.hs @@ -1,4 +1,4 @@ -module Database.PostgreSQL.Protocol.StatementStorage where +module Database.PostgreSQL.StatementStorage where import qualified Data.HashTable.IO as H import qualified Data.ByteString as B diff --git a/src/Database/PostgreSQL/Types.hs b/src/Database/PostgreSQL/Types.hs new file mode 100644 index 0000000..f62c3e0 --- /dev/null +++ b/src/Database/PostgreSQL/Types.hs @@ -0,0 +1,18 @@ +{- + * We dont store parameters of connection that may change after startup +-} +module Database.PostgreSQL.Types where + +import Data.ByteString (ByteString) + +import Database.PostgreSQL.Protocol.Types + +-- | Parameters of the current connection. +-- We store only the parameters that cannot change after startup. +-- For more information about additional parameters see documentation. +data ConnectionParameters = ConnectionParameters + { paramServerVersion :: ServerVersion + , paramServerEncoding :: ByteString -- ^ character set name + , paramIntegerDatetimes :: Bool -- ^ True if integer datetimes used + } deriving (Show) +