Add inner monad transformations

This commit is contained in:
Harendra Kumar 2019-11-10 03:49:13 +05:30
parent 8351c936c3
commit 9ad7fca9c7
4 changed files with 176 additions and 11 deletions

View File

@ -296,6 +296,10 @@ main =
, benchIOSink "tee" (Ops.transformTeeMapM serially 4)
, benchIOSink "zip" (Ops.transformZipMapM serially 4)
]
, bgroup "transformer"
[ benchIOSrc serially "evalState" Ops.evalStateT
, benchIOSrc serially "withState" Ops.withState
]
, bgroup "transformation"
[ benchIOSink "scanl" (Ops.scan 1)
, benchIOSink "scanl1'" (Ops.scanl1' 1)

View File

@ -25,6 +25,7 @@ module Streamly.Benchmark.Prelude where
import Control.DeepSeq (NFData)
import Control.Monad (when)
import Control.Monad.IO.Class (MonadIO)
import Control.Monad.State.Strict (StateT, get, put)
import Data.Functor.Identity (Identity, runIdentity)
import Data.Maybe (fromJust)
import GHC.Generics (Generic)
@ -167,6 +168,19 @@ sourceUnfoldrM n = S.unfoldrM step n
then return Nothing
else return (Just (cnt, cnt + 1))
{-# INLINE sourceUnfoldrState #-}
sourceUnfoldrState :: (S.IsStream t, S.MonadAsync m)
=> Int -> t (StateT Int m) Int
sourceUnfoldrState n = S.unfoldrM step n
where
step cnt =
if cnt > n + value
then return Nothing
else do
s <- get
put (s + 1)
return (Just (s, cnt + 1))
{-# INLINE sourceUnfoldrMN #-}
sourceUnfoldrMN :: (S.IsStream t, S.MonadAsync m) => Int -> Int -> t m Int
sourceUnfoldrMN upto start = S.unfoldrM step start
@ -208,6 +222,15 @@ runStream = S.drain
{-# INLINE toList #-}
toList :: Monad m => Stream m Int -> m [Int]
{-# INLINE evalStateT #-}
evalStateT :: S.MonadAsync m => Int -> Stream m Int
evalStateT n = Internal.evalStateT 0 (sourceUnfoldrState n)
{-# INLINE withState #-}
withState :: S.MonadAsync m => Int -> Stream m Int
withState n =
Internal.evalStateT (0 :: Int) (Internal.liftInner (sourceUnfoldrM n))
{-# INLINE head #-}
{-# INLINE last #-}
{-# INLINE maximum #-}

View File

@ -133,8 +133,6 @@ module Streamly.Internal.Prelude
-- * Transformation
, transform
, hoist
, generally
-- ** Mapping
, Serial.map
@ -376,6 +374,17 @@ module Streamly.Internal.Prelude
, finally
, handle
-- * Generalize Inner Monad
, hoist
, generally
-- * Transform Inner Monad
, liftInner
, runReaderT
, evalStateT
, usingStateT
, runStateT
-- * Diagnostics
, inspectMode
@ -399,6 +408,8 @@ import Control.Exception (Exception)
import Control.Monad (void)
import Control.Monad.Catch (MonadCatch)
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Reader (ReaderT)
import Control.Monad.State.Strict (StateT)
import Control.Monad.Trans (MonadTrans(..))
import Data.Functor.Identity (Identity (..))
import Data.Heap (Entry(..))
@ -1471,15 +1482,6 @@ toPureRev = foldl' (flip K.cons) K.nil
transform :: (IsStream t, Monad m) => Pipe m a b -> t m a -> t m b
transform pipe xs = fromStreamD $ D.transform pipe (toStreamD xs)
{-# INLINE hoist #-}
hoist :: (Monad m, Monad n)
=> (forall x. m x -> n x) -> SerialT m a -> SerialT n a
hoist f xs = fromStreamS $ S.hoist f (toStreamS xs)
{-# INLINE generally #-}
generally :: (IsStream t, Monad m) => t Identity a -> t m a
generally xs = fromStreamS $ S.hoist (return . runIdentity) (toStreamS xs)
------------------------------------------------------------------------------
-- Transformation by Folding (Scans)
------------------------------------------------------------------------------
@ -3696,3 +3698,84 @@ handle :: (IsStream t, MonadCatch m, Exception e)
=> (e -> t m a) -> t m a -> t m a
handle handler xs =
D.fromStreamD $ D.handle (\e -> D.toStreamD $ handler e) $ D.toStreamD xs
------------------------------------------------------------------------------
-- Generalize the underlying monad
------------------------------------------------------------------------------
-- | Transform the inner monad of a stream using a natural transformation.
--
-- / Internal/
--
{-# INLINE hoist #-}
hoist :: (Monad m, Monad n)
=> (forall x. m x -> n x) -> SerialT m a -> SerialT n a
hoist f xs = fromStreamS $ S.hoist f (toStreamS xs)
-- | Generalize the inner monad of the stream from 'Identity' to any monad.
--
-- / Internal/
--
{-# INLINE generally #-}
generally :: (IsStream t, Monad m) => t Identity a -> t m a
generally xs = fromStreamS $ S.hoist (return . runIdentity) (toStreamS xs)
------------------------------------------------------------------------------
-- Add and remove a monad transformer
------------------------------------------------------------------------------
-- | Lift the inner monad of a stream using a monad transformer.
--
-- / Internal/
--
{-# INLINE liftInner #-}
liftInner :: (Monad m, IsStream t, MonadTrans tr, Monad (tr m))
=> t m a -> t (tr m) a
liftInner xs = fromStreamD $ D.liftInner (toStreamD xs)
-- | Evaluate the inner monad of a stream as 'ReaderT'.
--
-- / Internal/
--
{-# INLINE runReaderT #-}
runReaderT :: (IsStream t, Monad m) => s -> t (ReaderT s m) a -> t m a
runReaderT s xs = fromStreamD $ D.runReaderT s (toStreamD xs)
-- | Evaluate the inner monad of a stream as 'StateT'.
--
-- This is supported only for 'SerialT' as concurrent state updation may not be
-- safe.
--
-- / Internal/
--
{-# INLINE evalStateT #-}
evalStateT :: Monad m => s -> SerialT (StateT s m) a -> SerialT m a
evalStateT s xs = fromStreamD $ D.evalStateT s (toStreamD xs)
-- | Run a stateful (StateT) stream transformation using a given state.
--
-- This is supported only for 'SerialT' as concurrent state updation may not be
-- safe.
--
-- / Internal/
--
{-# INLINE usingStateT #-}
usingStateT
:: Monad m
=> s
-> (SerialT (StateT s m) a -> SerialT (StateT s m) a)
-> SerialT m a
-> SerialT m a
usingStateT s f xs = evalStateT s $ f $ liftInner xs
-- | Evaluate the inner monad of a stream as 'StateT' and emit the resulting
-- state and value pair after each step.
--
-- This is supported only for 'SerialT' as concurrent state updation may not be
-- safe.
--
-- / Internal/
--
{-# INLINE runStateT #-}
runStateT :: Monad m => s -> SerialT (StateT s m) a -> SerialT m (s, a)
runStateT s xs = fromStreamD $ D.runStateT s (toStreamD xs)

View File

@ -188,9 +188,15 @@ module Streamly.Streams.StreamD
, toListRev
, toStreamK
, toStreamD
, hoist
, generally
, liftInner
, runReaderT
, evalStateT
, runStateT
-- * Transformation
, transform
@ -290,6 +296,8 @@ import Control.Exception (Exception, SomeException)
import Control.Monad (void)
import Control.Monad.Catch (MonadCatch)
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Reader (ReaderT)
import Control.Monad.State.Strict (StateT)
import Control.Monad.Trans (MonadTrans(lift))
import Data.Bits (shiftR, shiftL, (.|.), (.&.))
import Data.Functor.Identity (Identity(..))
@ -312,6 +320,8 @@ import Prelude
reverse)
import qualified Control.Monad.Catch as MC
import qualified Control.Monad.Reader as Reader
import qualified Control.Monad.State.Strict as State
import Streamly.Internal.Memory.Array.Types (Array(..))
import Streamly.Internal.Data.Fold.Types (Fold(..))
@ -650,6 +660,51 @@ hoist f (Stream step state) = (Stream step' state)
generally :: Monad m => Stream Identity a -> Stream m a
generally = hoist (return . runIdentity)
{-# INLINE_NORMAL liftInner #-}
liftInner :: (Monad m, MonadTrans t, Monad (t m))
=> Stream m a -> Stream (t m) a
liftInner (Stream step state) = Stream step' state
where
step' gst st = do
r <- lift $ step (adaptState gst) st
return $ case r of
Yield x s -> Yield x s
Skip s -> Skip s
Stop -> Stop
{-# INLINE_NORMAL runReaderT #-}
runReaderT :: Monad m => s -> Stream (ReaderT s m) a -> Stream m a
runReaderT sval (Stream step state) = Stream step' state
where
step' gst st = do
r <- Reader.runReaderT (step (adaptState gst) st) sval
return $ case r of
Yield x s -> Yield x s
Skip s -> Skip s
Stop -> Stop
{-# INLINE_NORMAL evalStateT #-}
evalStateT :: Monad m => s -> Stream (StateT s m) a -> Stream m a
evalStateT sval (Stream step state) = Stream step' (state, sval)
where
step' gst (st, sv) = do
(r, sv') <- State.runStateT (step (adaptState gst) st) sv
return $ case r of
Yield x s -> Yield x (s, sv')
Skip s -> Skip (s, sv')
Stop -> Stop
{-# INLINE_NORMAL runStateT #-}
runStateT :: Monad m => s -> Stream (StateT s m) a -> Stream m (s, a)
runStateT sval (Stream step state) = Stream step' (state, sval)
where
step' gst (st, sv) = do
(r, sv') <- State.runStateT (step (adaptState gst) st) sv
return $ case r of
Yield x s -> Yield (sv', x) (s, sv')
Skip s -> Skip (s, sv')
Stop -> Stop
------------------------------------------------------------------------------
-- Elimination by Folds
------------------------------------------------------------------------------