Add foldEither and foldConcat

This commit is contained in:
Harendra Kumar 2022-10-26 04:20:25 +05:30
parent 0717458443
commit 8d806c7029
4 changed files with 174 additions and 39 deletions

View File

@ -315,9 +315,6 @@ splitOnSuffix byte s =
-- Elimination - Running folds
-------------------------------------------------------------------------------
-- XXX This should be written using CPS (as foldK) if we want it to scale with
-- respect to the number of times it can be called on the same stream.
--
{-# INLINE_NORMAL foldBreakD #-}
foldBreakD :: forall m a b. (MonadIO m, Unbox a) =>
Fold m a b -> D.Stream m (Array a) -> m (b, D.Stream m (Array a))

View File

@ -23,6 +23,10 @@ module Streamly.Internal.Data.Stream.Bottom
, fold
, foldContinue
, foldBreak
, foldBreak2
, foldEither
, foldEither2
, foldConcat
-- * Scans
, smapM
@ -61,10 +65,13 @@ where
#include "inline.hs"
import Control.Monad.IO.Class (MonadIO(..))
import GHC.Types (SPEC(..))
import Streamly.Internal.Data.Fold.Type (Fold (..))
import Streamly.Internal.Data.Time.Units (AbsTime, RelTime64, addToAbsTime64)
import Streamly.Internal.Data.Unboxed (Unbox)
import Streamly.Internal.Data.Producer.Type (Producer(..))
import Streamly.Internal.System.IO (defaultChunkSize)
import Streamly.Internal.Data.SVar.Type (defState)
import qualified Streamly.Internal.Data.Array.Unboxed.Type as A
import qualified Streamly.Internal.Data.Fold as Fold
@ -214,21 +221,127 @@ fold fl strm = D.fold fl $ D.fromStreamK $ toStreamK strm
--
{-# INLINE foldBreak #-}
foldBreak :: Monad m => Fold m a b -> Stream m a -> m (b, Stream m a)
{-
-- XXX This shows quadratic performance when used recursively, perhaps because
-- the StreamK to StreamD conversions are not eliminated due to recursion.
foldBreak fl (Stream strm) = fmap f $ D.foldBreak fl $ D.fromStreamK strm

    where

    f (b, str) = (b, Stream (D.toStreamK str))
-}
-- Run the fold on the StreamK representation of the stream, then convert the
-- leftover StreamK back to a 'Stream'. The fold result is passed through
-- untouched.
foldBreak fl strm =
    fmap
        (\(res, rest) -> (res, fromStreamK rest))
        (K.foldBreak fl (toStreamK strm))
-- XXX The quadratic slowdown in recursive use arises because a recursive
-- function cannot be inlined, so the StreamD/StreamK conversions pile up and
-- cannot be eliminated by rewrite rules.
-- | A fusing variant of 'foldBreak': the fold is run on the StreamD
-- representation of the stream and the leftover is converted back.
--
-- /Note:/ In contrast to 'foldBreak', applying this recursively on the
-- returned stream leads to quadratic slowdown. When recursion is required
-- together with fusion (within a single iteration of the recursion), use
-- StreamD.foldBreak directly.
--
-- /Internal/
{-# INLINE foldBreak2 #-}
foldBreak2 :: Monad m => Fold m a b -> Stream m a -> m (b, Stream m a)
foldBreak2 fl strm = rewrap <$> D.foldBreak fl (toStreamD strm)

    where

    -- Keep the fold result, convert only the residual stream.
    rewrap (res, rest) = (res, fromStreamD rest)
-- | Run a fold on one segment of the input, ending either with a fold that
-- can be continued or with a finished result.
--
-- Rather than supplying the whole input stream in one go, the fold can be run
-- several times, each time on the next segment of the input. If the fold has
-- not finished yet, the result is a 'Left' carrying a fold that can be run
-- again; otherwise it is a 'Right' carrying the fold result and the residual
-- stream.
--
-- /Internal/
{-# INLINE foldEither #-}
foldEither :: Monad m =>
    Fold m a b -> Stream m a -> m (Either (Fold m a b) (b, Stream m a))
foldEither fl strm =
    fmap (fmap rewrap) (K.foldEither fl (toStreamK strm))

    where

    -- Only the 'Right' payload needs conversion: its residual StreamK is
    -- turned back into a 'Stream'.
    rewrap (res, rest) = (res, fromStreamK rest)
-- | A fusing variant of 'foldEither' that runs the fold on the StreamD
-- representation of the stream. Note that applying it recursively on the
-- resulting stream leads to quadratic slowdown.
--
-- /Internal/
{-# INLINE foldEither2 #-}
foldEither2 :: Monad m =>
    Fold m a b -> Stream m a -> m (Either (Fold m a b) (b, Stream m a))
foldEither2 fl strm =
    fmap
        (fmap (\(res, rest) -> (res, fromStreamD rest)))
        (D.foldEither fl (toStreamD strm))
-- XXX Array folds can be implemented using this.
-- foldContainers? Specialized to foldArrays.

-- | Expand every element of the input stream into an inner stream using the
-- supplied 'Producer', and run the supplied fold over the concatenation of
-- those inner streams. Returns the fold result together with the residual
-- stream.
--
-- If the fold finishes in the middle of an inner stream, the unconsumed
-- producer state is converted back to an element with the producer's extract
-- function and placed at the head of the residual stream.
--
-- For example, this can be used to efficiently fold an Array Word8 stream
-- using Word8 folds.
--
-- The outer stream is consumed in CPS form to allow scalable appends, whereas
-- the inner stream is consumed in direct style for stream fusion.
--
-- /Internal/
{-# INLINE foldConcat #-}
foldConcat :: Monad m =>
    Producer m a b -> Fold m b c -> Stream m a -> m (c, Stream m a)
foldConcat
    (Producer pstep pinject pextract)
    (Fold fstep begin done)
    stream = do

    initial <- begin
    case initial of
        -- The fold finished without consuming anything; hand back the input.
        Fold.Done result -> return (result, fromStreamK input)
        Fold.Partial fs0 -> outer fs0 input

    where

    input = toStreamK stream

    -- Extract the final result from the fold state; no input is left over.
    finish fs = do
        r <- done fs
        return (r, fromStreamK K.nil)

    -- Inject one outer element into the producer and drive the fold over
    -- everything it produces.
    feed fs a = do
        st <- pinject a
        inner SPEC fs st

    -- The fold finished with result 'b' while the producer was in state 's'.
    -- Turn the unconsumed state back into an element and let the caller
    -- decide how to build the residual stream from it.
    leftover b s mkRest = do
        x <- pextract s
        return (b, fromStreamK (mkRest x))

    -- Outer loop over the CPS stream, one element at a time.
    outer !fs strm =
        K.foldStream defState onYield onSingle onStop strm

        where

        onStop = finish fs

        onSingle a = do
            res <- feed fs a
            case res of
                Right (b, s) -> leftover b s K.fromPure
                Left fs1 -> finish fs1

        onYield a rest = do
            res <- feed fs a
            case res of
                Right (b, s) -> leftover b s (`K.cons` rest)
                Left fs1 -> outer fs1 rest

    -- Inner loop over a single producer. 'Left' carries the fold state when
    -- the producer is exhausted; 'Right' carries the fold result and the
    -- producer state at which the fold finished.
    {-# INLINE inner #-}
    inner !_ !fs st =
        pstep st >>= \r ->
            case r of
                D.Stop -> return (Left fs)
                D.Skip s -> inner SPEC fs s
                D.Yield x s ->
                    fstep fs x >>= \res ->
                        case res of
                            Fold.Partial fs1 -> inner SPEC fs1 s
                            Fold.Done b -> return (Right (b, s))
------------------------------------------------------------------------------
-- Transformation
------------------------------------------------------------------------------

View File

@ -44,7 +44,7 @@ module Streamly.Internal.Data.Stream.StreamD.Type
, fold
, foldBreak
, foldContinue
, foldStream
, foldEither
-- * Right Folds
, foldrT
@ -282,13 +282,14 @@ fold fld strm = do
(b, _) <- foldBreak fld strm
return b
{-# INLINE_NORMAL foldBreak #-}
foldBreak :: Monad m => Fold m a b -> Stream m a -> m (b, Stream m a)
foldBreak (Fold fstep begin done) (UnStream step state) = do
{-# INLINE_NORMAL foldEither #-}
foldEither :: Monad m =>
Fold m a b -> Stream m a -> m (Either (Fold m a b) (b, Stream m a))
foldEither (Fold fstep begin done) (UnStream step state) = do
res <- begin
case res of
FL.Partial fs -> go SPEC fs state
FL.Done fb -> return $! (fb, Stream step state)
FL.Done fb -> return $! Right (fb, Stream step state)
where
@ -299,12 +300,28 @@ foldBreak (Fold fstep begin done) (UnStream step state) = do
Yield x s -> do
res <- fstep fs x
case res of
FL.Done b -> return $! (b, Stream step s)
FL.Done b -> return $! Right (b, Stream step s)
FL.Partial fs1 -> go SPEC fs1 s
Skip s -> go SPEC fs s
Stop -> do
b <- done fs
return $! (b, Stream (\ _ _ -> return Stop) ())
Stop -> return $! Left (Fold fstep (return $ FL.Partial fs) done)
-- | Run the fold on the stream via 'foldEither' and return the fold result
-- paired with the remaining stream. When 'foldEither' hands back an
-- unfinished fold, the result is extracted from that fold's state and the
-- remaining stream is empty.
{-# INLINE_NORMAL foldBreak #-}
foldBreak :: Monad m => Fold m a b -> Stream m a -> m (b, Stream m a)
foldBreak fld strm = do
    outcome <- foldEither fld strm
    case outcome of
        Left pending -> drain pending
        Right finished -> return finished

    where

    -- A stream that stops immediately.
    emptyStream = Stream (\_ _ -> return Stop) ()

    -- Recover the accumulated state through the fold's initial action and
    -- extract the final result from it. The initial action is not expected
    -- to report Done here, hence the error on that branch.
    drain (Fold _ initial extract) = do
        st <- initial
        case st of
            FL.Partial s -> do
                b <- extract s
                return (b, emptyStream)
            FL.Done _ -> error "foldBreak: unreachable state"
-- | If the fold finishes before the stream, we can detect that the fold is
-- done by checking if the initial action returns Done. But the remaining
@ -334,16 +351,6 @@ foldContinue (Fold fstep finitial fextract) (Stream sstep state) =
Skip s -> go SPEC fs s
Stop -> return $ FL.Partial fs
-- | Returns when either the fold or the stream finishes. If the Fold finishes
-- first we can check that using Fold.done. If the fold is not done then the
-- stream would be nil.
--
-- This is only a placeholder: the definition is 'undefined', so evaluating it
-- fails at runtime.
--
-- /Unimplemented/
{-# INLINE_NORMAL foldStream #-}
-- NOTE(review): the Monad constraint is commented out, presumably because the
-- placeholder body does not use it -- confirm before implementing.
foldStream :: -- Monad m =>
Fold m a b -> Stream m a -> m (Fold m a b, Stream m a)
foldStream = undefined
------------------------------------------------------------------------------
-- Right Folds
------------------------------------------------------------------------------

View File

@ -83,6 +83,7 @@ module Streamly.Internal.Data.Stream.StreamK
, foldlMx'
, fold
, foldBreak
, foldEither
, parseBreak
-- ** Specialized Folds
@ -180,6 +181,7 @@ where
import Control.Monad.Catch (MonadThrow, throwM)
import Control.Monad.Trans.Class (MonadTrans(lift))
import Control.Monad (void, join)
import Streamly.Internal.Data.Fold.Type (Fold(..))
import Streamly.Internal.Data.SVar.Type (adaptState, defState)
import qualified Streamly.Internal.Data.Fold.Type as FL
@ -333,30 +335,46 @@ fold (FL.Fold step begin done) m = do
FL.Done b1 -> return b1
in foldStream defState yieldk single stop m1
{-# INLINE foldBreak #-}
foldBreak :: Monad m => FL.Fold m a b -> Stream m a -> m (b, Stream m a)
foldBreak (FL.Fold step begin done) m = do
{-# INLINE foldEither #-}
foldEither :: Monad m =>
Fold m a b -> Stream m a -> m (Either (Fold m a b) (b, Stream m a))
foldEither (FL.Fold step begin done) m = do
res <- begin
case res of
FL.Partial fs -> go fs m
FL.Done fb -> return (fb, m)
FL.Done fb -> return $ Right (fb, m)
where
go !acc m1 =
let stop = (, nil) <$> done acc
let stop = return $ Left (Fold step (return $ FL.Partial acc) done)
single a =
step acc a
>>= \case
FL.Partial s -> (, nil) <$> done s
FL.Done b1 -> return (b1, nil)
FL.Partial s ->
return $ Left (Fold step (return $ FL.Partial s) done)
FL.Done b1 -> return $ Right (b1, nil)
yieldk a r =
step acc a
>>= \case
FL.Partial s -> go s r
FL.Done b1 -> return (b1, r)
FL.Done b1 -> return $ Right (b1, r)
in foldStream defState yieldk single stop m1
-- | Run the fold on the stream via 'foldEither' and return the fold result
-- paired with the remaining stream. When 'foldEither' hands back an
-- unfinished fold, the result is extracted from that fold's state and the
-- remaining stream is 'nil'.
{-# INLINE foldBreak #-}
foldBreak :: Monad m => Fold m a b -> Stream m a -> m (b, Stream m a)
foldBreak fld strm =
    foldEither fld strm >>= \case
        Right finished -> return finished
        -- Recover the accumulated state through the fold's initial action
        -- and extract the final result from it. The initial action is not
        -- expected to report Done here, hence the error on that branch.
        Left (Fold _ initial extract) ->
            initial >>= \case
                FL.Partial s -> extract s >>= \b -> return (b, nil)
                FL.Done _ -> error "foldBreak: unreachable state"
-- | Like 'foldl'' but with a monadic step function.
{-# INLINE foldlM' #-}
foldlM' :: Monad m => (b -> a -> m b) -> m b -> Stream m a -> m b