Fix/simplify intersectBySorted

* Remove MonadIO and Eq constraints
* Simplify implementation
* Simplify tests
* Fix formatting
* Use longer benchmarks
This commit is contained in:
Harendra Kumar 2022-02-10 12:56:44 +05:30
parent 160393c8e0
commit ec20f5fdff
5 changed files with 73 additions and 90 deletions

View File

@ -417,57 +417,56 @@ o_n_space_monad value =
-- Joining
-------------------------------------------------------------------------------
toKvMap :: Int -> (Int, Int)
toKvMap p = (p, p)
toKv :: Int -> (Int, Int)
toKv p = (p, p)
{-# INLINE joinWith #-}
joinWith :: (S.MonadAsync m) =>
((Int -> Int -> Bool) -> SerialT m Int -> SerialT m Int -> SerialT m b)
-> Int
-> Int
-> Int
-> m ()
joinWith j val1 val2 i =
S.drain $ j (==) (sourceUnfoldrM val1 i) (sourceUnfoldrM val2 i)
joinWith j val i =
S.drain $ j (==) (sourceUnfoldrM val i) (sourceUnfoldrM val (val `div` 2))
{-# INLINE joinMapWith #-}
joinMapWith :: (S.MonadAsync m) =>
(SerialT m (Int, Int) -> SerialT m (Int, Int) -> SerialT m b)
-> Int
-> Int
-> Int
-> m ()
joinMapWith j val1 val2 i =
joinMapWith j val i =
S.drain
$ j
(fmap toKvMap (sourceUnfoldrM val1 i))
(fmap toKvMap (sourceUnfoldrM val2 i))
(fmap toKv (sourceUnfoldrM val i))
(fmap toKv (sourceUnfoldrM val (val `div` 2)))
o_n_heap_buffering :: Int -> [Benchmark]
o_n_heap_buffering value =
[ bgroup "buffered"
[
benchIOSrc1 "joinInner"
$ joinWith Internal.joinInner sqrtVal sqrtVal
benchIOSrc1 "joinInner (sqrtVal)"
$ joinWith Internal.joinInner sqrtVal
, benchIOSrc1 "joinInnerMap"
$ joinMapWith Internal.joinInnerMap sqrtVal sqrtVal
, benchIOSrc1 "joinLeft"
$ joinWith Internal.joinLeft sqrtVal sqrtVal
$ joinMapWith Internal.joinInnerMap halfVal
, benchIOSrc1 "joinLeft (sqrtVal)"
$ joinWith Internal.joinLeft sqrtVal
, benchIOSrc1 "joinLeftMap "
$ joinMapWith Internal.joinLeftMap sqrtVal sqrtVal
, benchIOSrc1 "joinOuter"
$ joinWith Internal.joinOuter sqrtVal sqrtVal
$ joinMapWith Internal.joinLeftMap halfVal
, benchIOSrc1 "joinOuter (sqrtVal)"
$ joinWith Internal.joinOuter sqrtVal
, benchIOSrc1 "joinOuterMap"
$ joinMapWith Internal.joinOuterMap sqrtVal sqrtVal
, benchIOSrc1 "intersectBy"
$ joinWith Internal.intersectBy sqrtVal sqrtVal
$ joinMapWith Internal.joinOuterMap halfVal
, benchIOSrc1 "intersectBy (sqrtVal)"
$ joinWith Internal.intersectBy sqrtVal
, benchIOSrc1 "intersectBySorted"
$ joinMapWith Internal.intersectBySorted sqrtVal sqrtVal
$ joinMapWith (Internal.intersectBySorted compare) halfVal
]
]
where
halfVal = value `div` 2
sqrtVal = round $ sqrt (fromIntegral value :: Double)
-------------------------------------------------------------------------------

View File

@ -28,7 +28,7 @@ module Streamly.Internal.Data.Stream.IsStream.Top
-- | These are not exactly set operations because streams are not
-- necessarily sets; they may contain duplicate elements.
, intersectBy
, intersectBySorted
, intersectBySorted
, differenceBy
, mergeDifferenceBy
, unionBy
@ -65,7 +65,6 @@ import Streamly.Internal.Data.Stream.IsStream.Common (concatM)
import Streamly.Internal.Data.Stream.IsStream.Type
(IsStream(..), adapt, foldl', fromList)
import Streamly.Internal.Data.Stream.Serial (SerialT)
--import Streamly.Internal.Data.Stream.StreamD (fromStreamD, toStreamD)
import Streamly.Internal.Data.Time.Units (NanoSecond64(..), toRelTime64)
import qualified Data.List as List
@ -576,7 +575,7 @@ intersectBy eq s1 s2 =
xs <- Stream.toListRev $ Stream.uniqBy eq $ adapt s2
return $ Stream.filter (\x -> List.any (eq x) xs) s1
-- | Like 'intersectBy' but works only on sorted streams.
-- | Like 'intersectBy' but works only on streams sorted in ascending order.
--
-- Space: O(1)
--
@ -584,10 +583,12 @@ intersectBy eq s1 s2 =
--
-- /Pre-release/
{-# INLINE intersectBySorted #-}
intersectBySorted :: (IsStream t, MonadIO m, Eq a) =>
intersectBySorted :: (IsStream t, Monad m) =>
(a -> a -> Ordering) -> t m a -> t m a -> t m a
intersectBySorted eq s1 =
IsStream.fromStreamD . StreamD.intersectBySorted eq (IsStream.toStreamD s1) . IsStream.toStreamD
IsStream.fromStreamD
. StreamD.intersectBySorted eq (IsStream.toStreamD s1)
. IsStream.toStreamD
-- Roughly joinLeft s1 s2 = s1 `difference` s2 + s1 `intersection` s2

View File

@ -484,57 +484,47 @@ mergeBy
mergeBy cmp = mergeByM (\a b -> return $ cmp a b)
-------------------------------------------------------------------------------
-- Intersection of sorted streams ---------------------------------------------
-- Intersection of sorted streams
-------------------------------------------------------------------------------
-- Both input streams are assumed to be sorted in ascending order according
-- to the supplied comparison function.
{-# INLINE_NORMAL intersectBySorted #-}
intersectBySorted
:: (MonadIO m, Eq a)
intersectBySorted :: Monad m
=> (a -> a -> Ordering) -> Stream m a -> Stream m a -> Stream m a
intersectBySorted cmp (Stream stepa ta) (Stream stepb tb) =
Stream step (Just ta, Just tb, Nothing, Nothing, Nothing)
Stream step
( ta -- left stream state
, tb -- right stream state
, Nothing -- left value
, Nothing -- right value
)
where
{-# INLINE_LATE step #-}
-- step 1
step gst (Just sa, sb, Nothing, b, Nothing) = do
{-# INLINE_LATE step #-}
-- step 1, fetch the first value
step gst (sa, sb, Nothing, b) = do
r <- stepa gst sa
return $ case r of
Yield a sa' -> Skip (Just sa', sb, Just a, b, Nothing)
Skip sa' -> Skip (Just sa', sb, Nothing, b, Nothing)
Yield a sa' -> Skip (sa', sb, Just a, b) -- step 2/3
Skip sa' -> Skip (sa', sb, Nothing, b)
Stop -> Stop
-- step 2
step gst (sa, Just sb, a, Nothing, Nothing) = do
-- step 2, fetch the second value
step gst (sa, sb, a@(Just _), Nothing) = do
r <- stepb gst sb
return $ case r of
Yield b sb' -> Skip (sa, Just sb', a, Just b, Nothing)
Skip sb' -> Skip (sa, Just sb', a, Nothing, Nothing)
Yield b sb' -> Skip (sa, sb', a, Just b) -- step 3
Skip sb' -> Skip (sa, sb', a, Nothing)
Stop -> Stop
-- step 3
-- both the values are available compare it
step _ (sa, sb, Just a, Just b, Nothing) = do
-- step 3, compare the two values
step _ (sa, sb, Just a, Just b) = do
let res = cmp a b
return $ case res of
GT -> Skip (sa, sb, Just a, Nothing, Nothing)
LT -> Skip (sa, sb, Nothing, Just b, Nothing)
EQ -> Yield a (sa, sb, Nothing, Just a, Just b) -- step 4
-- step 4
-- Matching element
step gst (Just sa, Just sb, Nothing, Just _, Just b) = do
r1 <- stepa gst sa
return $ case r1 of
Yield a' sa' -> do
if a' == b -- match with prev a
then Yield a' (Just sa', Just sb, Nothing, Just b, Just b) --step 1
else Skip (Just sa', Just sb, Just a', Nothing, Nothing)
Skip sa' -> Skip (Just sa', Just sb, Nothing, Nothing, Nothing)
Stop -> Stop
step _ (_, _, _, _, _) = return Stop
GT -> Skip (sa, sb, Just a, Nothing) -- step 2
LT -> Skip (sa, sb, Nothing, Just b) -- step 1
EQ -> Yield a (sa, sb, Nothing, Just b) -- step 1
------------------------------------------------------------------------------
-- Combine N Streams - unfoldMany

View File

@ -101,7 +101,7 @@ extra-source-files:
test/Streamly/Test/Data/Array/Prim/Pinned.hs
test/Streamly/Test/Data/Array/Foreign.hs
test/Streamly/Test/Data/Array/Stream/Foreign.hs
test/Streamly/Test/Data/Parser/ParserD.hs
test/Streamly/Test/Data/Parser/ParserD.hs
test/Streamly/Test/FileSystem/Event.hs
test/Streamly/Test/FileSystem/Event/Common.hs
test/Streamly/Test/FileSystem/Event/Darwin.hs

View File

@ -1,8 +1,8 @@
module Main (main)
where
module Main (main) where
import Data.List (elem, intersect, nub, sort)
import Data.Maybe (isNothing)
import Streamly.Prelude (SerialT)
import Test.QuickCheck
( Gen
, Property
@ -169,8 +169,16 @@ joinLeftMap =
let v2 = joinLeftList ls0 ls1
assert (v1 == v2)
intersectBy :: Property
intersectBy =
intersectBy ::
([Int] -> [Int])
-> ( (Int -> Int -> a)
-> SerialT IO Int
-> SerialT IO Int
-> SerialT IO Int
)
-> (Int -> Int -> a)
-> Property
intersectBy _srt intersectFunc cmp =
forAll (listOf (chooseInt (min_value, max_value))) $ \ls0 ->
forAll (listOf (chooseInt (min_value, max_value))) $ \ls1 ->
monadicIO $ action (sort ls0) (sort ls1)
@ -181,33 +189,17 @@ intersectBy =
v1 <-
run
$ S.toList
$ Top.intersectBy
(==)
$ intersectFunc
cmp
(S.fromList ls0)
(S.fromList ls1)
let v2 = intersect ls0 ls1
assert (v1 == sort v2)
intersectBySorted :: Property
intersectBySorted =
forAll (listOf (chooseInt (min_value, max_value))) $ \ls0 ->
forAll (listOf (chooseInt (min_value, max_value))) $ \ls1 ->
monadicIO $ action (sort ls0) (sort ls1)
where
action ls0 ls1 = do
v1 <-
run
$ S.toList
$ Top.intersectBySorted
compare
(S.fromList ls0)
(S.fromList ls1)
let v2 = intersect ls0 ls1
let v2 = ls0 `intersect` ls1
assert (v1 == sort v2)
-------------------------------------------------------------------------------
-- Main
-------------------------------------------------------------------------------
moduleName :: String
moduleName = "Prelude.Top"
@ -215,7 +207,6 @@ main :: IO ()
main = hspec $ do
describe moduleName $ do
-- Joins
prop "joinInner" Main.joinInner
prop "joinInnerMap" Main.joinInnerMap
-- XXX The API is currently broken, see https://github.com/composewell/streamly/issues/1032
@ -224,5 +215,7 @@ main = hspec $ do
prop "joinLeft" Main.joinLeft
prop "joinLeftMap" Main.joinLeftMap
-- intersect
prop "intersectBy" Main.intersectBy
prop "intersectBySorted" Main.intersectBySorted
-- XXX The API is currently broken, see https://github.com/composewell/streamly/issues/1471
--prop "intersectBy" (intersectBy id Top.intersectBy (==))
prop "intersectBySorted"
(intersectBy sort Top.intersectBySorted compare)