Refactor/add cross product APIs

This commit is contained in:
Harendra Kumar 2021-03-24 01:39:42 +05:30
parent a6e8e23062
commit c5b2ad3740
4 changed files with 54 additions and 52 deletions

View File

@ -97,8 +97,10 @@ module Streamly.Data.Unfold
, zipWithM
, zipWith
-- ** Cross Product
, crossWith
-- ** Nesting
, cross
, many
-- ** Exceptions

View File

@ -161,14 +161,17 @@ module Streamly.Internal.Data.Unfold
, zipWith
, teeZipWith
-- ** Nesting
-- ** Cross product
, crossWithM
, crossWith
, cross
, apply
-- ** Nesting
, ConcatState (..)
, many
, concatMapM
, bind
, outerProduct
-- ** Exceptions
, gbracket_
@ -896,45 +899,6 @@ teeZipWith :: Monad m
=> (a -> b -> c) -> Unfold m x a -> Unfold m x b -> Unfold m x c
teeZipWith f unf1 unf2 = lmap (\x -> (x,x)) $ zipWith f unf1 unf2
-------------------------------------------------------------------------------
-- Nested
-------------------------------------------------------------------------------
data OuterProductState s1 s2 sy x y =
OuterProductOuter s1 y | OuterProductInner s1 sy s2 x
-- XXX this can be written in terms of "cross".
-- XXX Remove this in favor of cross?
--
-- | Create an outer product (vector product or cartesian product) of the
-- output streams of two unfolds.
--
{-# INLINE_NORMAL outerProduct #-}
outerProduct :: Monad m
=> Unfold m a b -> Unfold m c d -> Unfold m (a, c) (b, d)
outerProduct (Unfold step1 inject1) (Unfold step2 inject2) = Unfold step inject
where
inject (x, y) = do
s1 <- inject1 x
return $ OuterProductOuter s1 y
{-# INLINE_LATE step #-}
step (OuterProductOuter st1 sy) = do
r <- step1 st1
case r of
Yield x s -> do
s2 <- inject2 sy
return $ Skip (OuterProductInner s sy s2 x)
Skip s -> return $ Skip (OuterProductOuter s sy)
Stop -> return Stop
step (OuterProductInner ost sy ist x) = do
r <- step2 ist
return $ case r of
Yield y s -> Yield (x, y) (OuterProductInner ost sy s x)
Skip s -> Skip (OuterProductInner ost sy s x)
Stop -> Skip (OuterProductOuter ost sy)
------------------------------------------------------------------------------
-- Exceptions
------------------------------------------------------------------------------

View File

@ -24,14 +24,18 @@ module Streamly.Internal.Data.Unfold.Type
, ConcatState (..)
, many
-- Applicative
, apSequence
, apDiscardSnd
, crossWithM
, crossWith
, cross
, apply
, bind
-- Monad
, concatMapM
, concatMap
, bind
, zipWithM
, zipWith
@ -144,11 +148,13 @@ apDiscardSnd (Unfold _step1 _inject1) (Unfold _step2 _inject2) = undefined
data Cross a s1 b s2 = CrossOuter a s1 | CrossInner a s1 b s2
-- | Create a cross product (vector product or cartesian product) of the
-- output streams of two unfolds.
-- output streams of two unfolds using a monadic combining function.
--
{-# INLINE_NORMAL cross #-}
cross :: Monad m => Unfold m a b -> Unfold m a c -> Unfold m a (b, c)
cross (Unfold step1 inject1) (Unfold step2 inject2) = Unfold step inject
-- /Pre-release/
{-# INLINE_NORMAL crossWithM #-}
crossWithM :: Monad m =>
(b -> c -> m d) -> Unfold m a b -> Unfold m a c -> Unfold m a d
crossWithM f (Unfold step1 inject1) (Unfold step2 inject2) = Unfold step inject
where
@ -168,10 +174,36 @@ cross (Unfold step1 inject1) (Unfold step2 inject2) = Unfold step inject
step (CrossInner a s1 b s2) = do
r <- step2 s2
return $ case r of
Yield c s -> Yield (b, c) (CrossInner a s1 b s)
Skip s -> Skip (CrossInner a s1 b s)
Stop -> Skip (CrossOuter a s1)
case r of
Yield c s -> f b c >>= \d -> return $ Yield d (CrossInner a s1 b s)
Skip s -> return $ Skip (CrossInner a s1 b s)
Stop -> return $ Skip (CrossOuter a s1)
-- | Like 'crossWithM' but uses a pure combining function.
--
-- > crossWith f = crossWithM (\b c -> return $ f b c)
--
-- /Pre-release/
{-# INLINE crossWith #-}
crossWith :: Monad m =>
(b -> c -> d) -> Unfold m a b -> Unfold m a c -> Unfold m a d
crossWith f = crossWithM (\b c -> return $ f b c)
-- | See 'crossWith'.
--
-- > cross = crossWith (,)
--
-- To cross the streams from a tuple we can write:
--
-- @
-- crossProduct :: Monad m => Unfold m a b -> Unfold m c d -> Unfold m (a, c) (b, d)
-- crossProduct u1 u2 = cross (lmap fst u1) (lmap snd u2)
-- @
--
-- /Pre-release/
{-# INLINE_NORMAL cross #-}
cross :: Monad m => Unfold m a b -> Unfold m a c -> Unfold m a (b, c)
cross = crossWith (,)
apply :: Monad m => Unfold m a (b -> c) -> Unfold m a b -> Unfold m a c
apply u1 u2 = fmap (\(a, b) -> a b) (cross u1 u2)

View File

@ -335,10 +335,14 @@ outerProduct :: Bool
outerProduct =
let unf1 = UF.enumerateFromToIntegral 10
unf2 = UF.enumerateFromToIntegral 20
unf = UF.outerProduct unf1 unf2
unf = crossProduct unf1 unf2
lst = [(a, b) :: (Int, Int) | a <- [0 .. 10], b <- [0 .. 20]]
in testUnfold unf ((0, 0) :: (Int, Int)) lst
where
crossProduct u1 u2 = UF.cross (UF.lmap fst u1) (UF.lmap snd u2)
concatMapM :: Bool
concatMapM =
let inner b =