From 9ad7fca9c743fa672897c46bebc02178efac2e53 Mon Sep 17 00:00:00 2001 From: Harendra Kumar Date: Sun, 10 Nov 2019 03:49:13 +0530 Subject: [PATCH] Add inner monad transformations --- benchmark/Linear.hs | 4 ++ src/Streamly/Benchmark/Prelude.hs | 23 +++++++ src/Streamly/Internal/Prelude.hs | 105 ++++++++++++++++++++++++++---- src/Streamly/Streams/StreamD.hs | 55 ++++++++++++++++ 4 files changed, 176 insertions(+), 11 deletions(-) diff --git a/benchmark/Linear.hs b/benchmark/Linear.hs index fa7901231..f6ead0a9f 100644 --- a/benchmark/Linear.hs +++ b/benchmark/Linear.hs @@ -296,6 +296,10 @@ main = , benchIOSink "tee" (Ops.transformTeeMapM serially 4) , benchIOSink "zip" (Ops.transformZipMapM serially 4) ] + , bgroup "transformer" + [ benchIOSrc serially "evalState" Ops.evalStateT + , benchIOSrc serially "withState" Ops.withState + ] , bgroup "transformation" [ benchIOSink "scanl" (Ops.scan 1) , benchIOSink "scanl1'" (Ops.scanl1' 1) diff --git a/src/Streamly/Benchmark/Prelude.hs b/src/Streamly/Benchmark/Prelude.hs index b27db2a0f..9c8a2d52d 100644 --- a/src/Streamly/Benchmark/Prelude.hs +++ b/src/Streamly/Benchmark/Prelude.hs @@ -25,6 +25,7 @@ module Streamly.Benchmark.Prelude where import Control.DeepSeq (NFData) import Control.Monad (when) import Control.Monad.IO.Class (MonadIO) +import Control.Monad.State.Strict (StateT, get, put) import Data.Functor.Identity (Identity, runIdentity) import Data.Maybe (fromJust) import GHC.Generics (Generic) @@ -167,6 +168,19 @@ sourceUnfoldrM n = S.unfoldrM step n then return Nothing else return (Just (cnt, cnt + 1)) +{-# INLINE sourceUnfoldrState #-} +sourceUnfoldrState :: (S.IsStream t, S.MonadAsync m) + => Int -> t (StateT Int m) Int +sourceUnfoldrState n = S.unfoldrM step n + where + step cnt = + if cnt > n + value + then return Nothing + else do + s <- get + put (s + 1) + return (Just (s, cnt + 1)) + {-# INLINE sourceUnfoldrMN #-} sourceUnfoldrMN :: (S.IsStream t, S.MonadAsync m) => Int -> Int -> t m Int sourceUnfoldrMN upto start = S.unfoldrM step start @@ -208,6 +222,15 @@ runStream = S.drain {-# INLINE toList #-} toList :: Monad m => Stream m Int -> m [Int] +{-# INLINE evalStateT #-} +evalStateT :: S.MonadAsync m => Int -> Stream m Int +evalStateT n = Internal.evalStateT 0 (sourceUnfoldrState n) + +{-# INLINE withState #-} +withState :: S.MonadAsync m => Int -> Stream m Int +withState n = + Internal.evalStateT (0 :: Int) (Internal.liftInner (sourceUnfoldrM n)) + {-# INLINE head #-} {-# INLINE last #-} {-# INLINE maximum #-} diff --git a/src/Streamly/Internal/Prelude.hs b/src/Streamly/Internal/Prelude.hs index e2d9e6c8c..6bc63b7a4 100644 --- a/src/Streamly/Internal/Prelude.hs +++ b/src/Streamly/Internal/Prelude.hs @@ -133,8 +133,6 @@ module Streamly.Internal.Prelude -- * Transformation , transform - , hoist - , generally -- ** Mapping , Serial.map @@ -376,6 +374,17 @@ module Streamly.Internal.Prelude , finally , handle + -- * Generalize Inner Monad + , hoist + , generally + + -- * Transform Inner Monad + , liftInner + , runReaderT + , evalStateT + , usingStateT + , runStateT + -- * Diagnostics , inspectMode @@ -399,6 +408,8 @@ import Control.Exception (Exception) import Control.Monad (void) import Control.Monad.Catch (MonadCatch) import Control.Monad.IO.Class (MonadIO(..)) +import Control.Monad.Reader (ReaderT) +import Control.Monad.State.Strict (StateT) import Control.Monad.Trans (MonadTrans(..)) import Data.Functor.Identity (Identity (..)) import Data.Heap (Entry(..)) @@ -1471,15 +1482,6 @@ toPureRev = foldl' (flip K.cons) K.nil transform :: (IsStream t, Monad m) => Pipe m a b -> t m a -> t m b transform pipe xs = fromStreamD $ D.transform pipe (toStreamD xs) -{-# INLINE hoist #-} -hoist :: (Monad m, Monad n) - => (forall x. m x -> n x) -> SerialT m a -> SerialT n a -hoist f xs = fromStreamS $ S.hoist f (toStreamS xs) - -{-# INLINE generally #-} -generally :: (IsStream t, Monad m) => t Identity a -> t m a -generally xs = fromStreamS $ S.hoist (return . runIdentity) (toStreamS xs) - ------------------------------------------------------------------------------ -- Transformation by Folding (Scans) ------------------------------------------------------------------------------ @@ -3696,3 +3698,84 @@ handle :: (IsStream t, MonadCatch m, Exception e) => (e -> t m a) -> t m a -> t m a handle handler xs = D.fromStreamD $ D.handle (\e -> D.toStreamD $ handler e) $ D.toStreamD xs + +------------------------------------------------------------------------------ +-- Generalize the underlying monad +------------------------------------------------------------------------------ + +-- | Transform the inner monad of a stream using a natural transformation. +-- +-- / Internal/ +-- +{-# INLINE hoist #-} +hoist :: (Monad m, Monad n) + => (forall x. m x -> n x) -> SerialT m a -> SerialT n a +hoist f xs = fromStreamS $ S.hoist f (toStreamS xs) + +-- | Generalize the inner monad of the stream from 'Identity' to any monad. +-- +-- / Internal/ +-- +{-# INLINE generally #-} +generally :: (IsStream t, Monad m) => t Identity a -> t m a +generally xs = fromStreamS $ S.hoist (return . runIdentity) (toStreamS xs) + +------------------------------------------------------------------------------ +-- Add and remove a monad transformer +------------------------------------------------------------------------------ + +-- | Lift the inner monad of a stream using a monad transformer. +-- +-- / Internal/ +-- +{-# INLINE liftInner #-} +liftInner :: (Monad m, IsStream t, MonadTrans tr, Monad (tr m)) + => t m a -> t (tr m) a +liftInner xs = fromStreamD $ D.liftInner (toStreamD xs) + +-- | Evaluate the inner monad of a stream as 'ReaderT'. +-- +-- / Internal/ +-- +{-# INLINE runReaderT #-} +runReaderT :: (IsStream t, Monad m) => s -> t (ReaderT s m) a -> t m a +runReaderT s xs = fromStreamD $ D.runReaderT s (toStreamD xs) + +-- | Evaluate the inner monad of a stream as 'StateT'. +-- +-- This is supported only for 'SerialT' as concurrent state updation may not be +-- safe. +-- +-- / Internal/ +-- +{-# INLINE evalStateT #-} +evalStateT :: Monad m => s -> SerialT (StateT s m) a -> SerialT m a +evalStateT s xs = fromStreamD $ D.evalStateT s (toStreamD xs) + +-- | Run a stateful (StateT) stream transformation using a given state. +-- +-- This is supported only for 'SerialT' as concurrent state updation may not be +-- safe. +-- +-- / Internal/ +-- +{-# INLINE usingStateT #-} +usingStateT + :: Monad m + => s + -> (SerialT (StateT s m) a -> SerialT (StateT s m) a) + -> SerialT m a + -> SerialT m a +usingStateT s f xs = evalStateT s $ f $ liftInner xs + +-- | Evaluate the inner monad of a stream as 'StateT' and emit the resulting +-- state and value pair after each step. +-- +-- This is supported only for 'SerialT' as concurrent state updation may not be +-- safe. +-- +-- / Internal/ +-- +{-# INLINE runStateT #-} +runStateT :: Monad m => s -> SerialT (StateT s m) a -> SerialT m (s, a) +runStateT s xs = fromStreamD $ D.runStateT s (toStreamD xs) diff --git a/src/Streamly/Streams/StreamD.hs b/src/Streamly/Streams/StreamD.hs index 4569b2533..6e851d7b7 100644 --- a/src/Streamly/Streams/StreamD.hs +++ b/src/Streamly/Streams/StreamD.hs @@ -188,9 +188,15 @@ module Streamly.Streams.StreamD , toListRev , toStreamK , toStreamD + , hoist , generally + , liftInner + , runReaderT + , evalStateT + , runStateT + -- * Transformation , transform @@ -290,6 +296,8 @@ import Control.Exception (Exception, SomeException) import Control.Monad (void) import Control.Monad.Catch (MonadCatch) import Control.Monad.IO.Class (MonadIO(..)) +import Control.Monad.Reader (ReaderT) +import Control.Monad.State.Strict (StateT) import Control.Monad.Trans (MonadTrans(lift)) import Data.Bits (shiftR, shiftL, (.|.), (.&.)) import Data.Functor.Identity (Identity(..)) @@ -312,6 +320,8 @@ import Prelude reverse) import qualified Control.Monad.Catch as MC +import qualified Control.Monad.Reader as Reader +import qualified Control.Monad.State.Strict as State import Streamly.Internal.Memory.Array.Types (Array(..)) import Streamly.Internal.Data.Fold.Types (Fold(..)) @@ -650,6 +660,51 @@ hoist f (Stream step state) = (Stream step' state) generally :: Monad m => Stream Identity a -> Stream m a generally = hoist (return . runIdentity) +{-# INLINE_NORMAL liftInner #-} +liftInner :: (Monad m, MonadTrans t, Monad (t m)) + => Stream m a -> Stream (t m) a +liftInner (Stream step state) = Stream step' state + where + step' gst st = do + r <- lift $ step (adaptState gst) st + return $ case r of + Yield x s -> Yield x s + Skip s -> Skip s + Stop -> Stop + +{-# INLINE_NORMAL runReaderT #-} +runReaderT :: Monad m => s -> Stream (ReaderT s m) a -> Stream m a +runReaderT sval (Stream step state) = Stream step' state + where + step' gst st = do + r <- Reader.runReaderT (step (adaptState gst) st) sval + return $ case r of + Yield x s -> Yield x s + Skip s -> Skip s + Stop -> Stop + +{-# INLINE_NORMAL evalStateT #-} +evalStateT :: Monad m => s -> Stream (StateT s m) a -> Stream m a +evalStateT sval (Stream step state) = Stream step' (state, sval) + where + step' gst (st, sv) = do + (r, sv') <- State.runStateT (step (adaptState gst) st) sv + return $ case r of + Yield x s -> Yield x (s, sv') + Skip s -> Skip (s, sv') + Stop -> Stop + +{-# INLINE_NORMAL runStateT #-} +runStateT :: Monad m => s -> Stream (StateT s m) a -> Stream m (s, a) +runStateT sval (Stream step state) = Stream step' (state, sval) + where + step' gst (st, sv) = do + (r, sv') <- State.runStateT (step (adaptState gst) st) sv + return $ case r of + Yield x s -> Yield (sv', x) (s, sv') + Skip s -> Skip (s, sv') + Stop -> Stop + ------------------------------------------------------------------------------ -- Elimination by Folds ------------------------------------------------------------------------------