mirror of
https://github.com/ilyakooo0/streamly.git
synced 2024-09-17 11:37:20 +03:00
Fix/simplify intersectBySorted
* Remove MonadIO and Eq constraints * Simplify implementation * Simplify tests * Fix formatting * Use longer benchmarks
This commit is contained in:
parent
160393c8e0
commit
ec20f5fdff
@ -417,57 +417,56 @@ o_n_space_monad value =
|
||||
-- Joining
|
||||
-------------------------------------------------------------------------------
|
||||
|
||||
toKvMap :: Int -> (Int, Int)
|
||||
toKvMap p = (p, p)
|
||||
toKv :: Int -> (Int, Int)
|
||||
toKv p = (p, p)
|
||||
|
||||
{-# INLINE joinWith #-}
|
||||
joinWith :: (S.MonadAsync m) =>
|
||||
((Int -> Int -> Bool) -> SerialT m Int -> SerialT m Int -> SerialT m b)
|
||||
-> Int
|
||||
-> Int
|
||||
-> Int
|
||||
-> m ()
|
||||
joinWith j val1 val2 i =
|
||||
S.drain $ j (==) (sourceUnfoldrM val1 i) (sourceUnfoldrM val2 i)
|
||||
joinWith j val i =
|
||||
S.drain $ j (==) (sourceUnfoldrM val i) (sourceUnfoldrM val (val `div` 2))
|
||||
|
||||
{-# INLINE joinMapWith #-}
|
||||
joinMapWith :: (S.MonadAsync m) =>
|
||||
(SerialT m (Int, Int) -> SerialT m (Int, Int) -> SerialT m b)
|
||||
-> Int
|
||||
-> Int
|
||||
-> Int
|
||||
-> m ()
|
||||
joinMapWith j val1 val2 i =
|
||||
joinMapWith j val i =
|
||||
S.drain
|
||||
$ j
|
||||
(fmap toKvMap (sourceUnfoldrM val1 i))
|
||||
(fmap toKvMap (sourceUnfoldrM val2 i))
|
||||
(fmap toKv (sourceUnfoldrM val i))
|
||||
(fmap toKv (sourceUnfoldrM val (val `div` 2)))
|
||||
|
||||
o_n_heap_buffering :: Int -> [Benchmark]
|
||||
o_n_heap_buffering value =
|
||||
[ bgroup "buffered"
|
||||
[
|
||||
benchIOSrc1 "joinInner"
|
||||
$ joinWith Internal.joinInner sqrtVal sqrtVal
|
||||
benchIOSrc1 "joinInner (sqrtVal)"
|
||||
$ joinWith Internal.joinInner sqrtVal
|
||||
, benchIOSrc1 "joinInnerMap"
|
||||
$ joinMapWith Internal.joinInnerMap sqrtVal sqrtVal
|
||||
, benchIOSrc1 "joinLeft"
|
||||
$ joinWith Internal.joinLeft sqrtVal sqrtVal
|
||||
$ joinMapWith Internal.joinInnerMap halfVal
|
||||
, benchIOSrc1 "joinLeft (sqrtVal)"
|
||||
$ joinWith Internal.joinLeft sqrtVal
|
||||
, benchIOSrc1 "joinLeftMap "
|
||||
$ joinMapWith Internal.joinLeftMap sqrtVal sqrtVal
|
||||
, benchIOSrc1 "joinOuter"
|
||||
$ joinWith Internal.joinOuter sqrtVal sqrtVal
|
||||
$ joinMapWith Internal.joinLeftMap halfVal
|
||||
, benchIOSrc1 "joinOuter (sqrtVal)"
|
||||
$ joinWith Internal.joinOuter sqrtVal
|
||||
, benchIOSrc1 "joinOuterMap"
|
||||
$ joinMapWith Internal.joinOuterMap sqrtVal sqrtVal
|
||||
, benchIOSrc1 "intersectBy"
|
||||
$ joinWith Internal.intersectBy sqrtVal sqrtVal
|
||||
$ joinMapWith Internal.joinOuterMap halfVal
|
||||
, benchIOSrc1 "intersectBy (sqrtVal)"
|
||||
$ joinWith Internal.intersectBy sqrtVal
|
||||
, benchIOSrc1 "intersectBySorted"
|
||||
$ joinMapWith Internal.intersectBySorted sqrtVal sqrtVal
|
||||
$ joinMapWith (Internal.intersectBySorted compare) halfVal
|
||||
]
|
||||
]
|
||||
|
||||
where
|
||||
|
||||
halfVal = value `div` 2
|
||||
sqrtVal = round $ sqrt (fromIntegral value :: Double)
|
||||
|
||||
-------------------------------------------------------------------------------
|
||||
|
@ -28,7 +28,7 @@ module Streamly.Internal.Data.Stream.IsStream.Top
|
||||
-- | These are not exactly set operations because streams are not
|
||||
-- necessarily sets, they may have duplicated elements.
|
||||
, intersectBy
|
||||
, intersectBySorted
|
||||
, intersectBySorted
|
||||
, differenceBy
|
||||
, mergeDifferenceBy
|
||||
, unionBy
|
||||
@ -65,7 +65,6 @@ import Streamly.Internal.Data.Stream.IsStream.Common (concatM)
|
||||
import Streamly.Internal.Data.Stream.IsStream.Type
|
||||
(IsStream(..), adapt, foldl', fromList)
|
||||
import Streamly.Internal.Data.Stream.Serial (SerialT)
|
||||
--import Streamly.Internal.Data.Stream.StreamD (fromStreamD, toStreamD)
|
||||
import Streamly.Internal.Data.Time.Units (NanoSecond64(..), toRelTime64)
|
||||
|
||||
import qualified Data.List as List
|
||||
@ -576,7 +575,7 @@ intersectBy eq s1 s2 =
|
||||
xs <- Stream.toListRev $ Stream.uniqBy eq $ adapt s2
|
||||
return $ Stream.filter (\x -> List.any (eq x) xs) s1
|
||||
|
||||
-- | Like 'intersectBy' but works only on sorted streams.
|
||||
-- | Like 'intersectBy' but works only on streams sorted in ascending order.
|
||||
--
|
||||
-- Space: O(1)
|
||||
--
|
||||
@ -584,10 +583,12 @@ intersectBy eq s1 s2 =
|
||||
--
|
||||
-- /Pre-release/
|
||||
{-# INLINE intersectBySorted #-}
|
||||
intersectBySorted :: (IsStream t, MonadIO m, Eq a) =>
|
||||
intersectBySorted :: (IsStream t, Monad m) =>
|
||||
(a -> a -> Ordering) -> t m a -> t m a -> t m a
|
||||
intersectBySorted eq s1 =
|
||||
IsStream.fromStreamD . StreamD.intersectBySorted eq (IsStream.toStreamD s1) . IsStream.toStreamD
|
||||
IsStream.fromStreamD
|
||||
. StreamD.intersectBySorted eq (IsStream.toStreamD s1)
|
||||
. IsStream.toStreamD
|
||||
|
||||
-- Roughly joinLeft s1 s2 = s1 `difference` s2 + s1 `intersection` s2
|
||||
|
||||
|
@ -484,57 +484,47 @@ mergeBy
|
||||
mergeBy cmp = mergeByM (\a b -> return $ cmp a b)
|
||||
|
||||
-------------------------------------------------------------------------------
|
||||
-- Intersection of sorted streams ---------------------------------------------
|
||||
-- Intersection of sorted streams
|
||||
-------------------------------------------------------------------------------
|
||||
|
||||
-- Assuming the streams are sorted in ascending order
|
||||
{-# INLINE_NORMAL intersectBySorted #-}
|
||||
intersectBySorted
|
||||
:: (MonadIO m, Eq a)
|
||||
intersectBySorted :: Monad m
|
||||
=> (a -> a -> Ordering) -> Stream m a -> Stream m a -> Stream m a
|
||||
intersectBySorted cmp (Stream stepa ta) (Stream stepb tb) =
|
||||
Stream step (Just ta, Just tb, Nothing, Nothing, Nothing)
|
||||
Stream step
|
||||
( ta -- left stream state
|
||||
, tb -- right stream state
|
||||
, Nothing -- left value
|
||||
, Nothing -- right value
|
||||
)
|
||||
|
||||
where
|
||||
{-# INLINE_LATE step #-}
|
||||
|
||||
-- step 1
|
||||
step gst (Just sa, sb, Nothing, b, Nothing) = do
|
||||
{-# INLINE_LATE step #-}
|
||||
-- step 1, fetch the first value
|
||||
step gst (sa, sb, Nothing, b) = do
|
||||
r <- stepa gst sa
|
||||
return $ case r of
|
||||
Yield a sa' -> Skip (Just sa', sb, Just a, b, Nothing)
|
||||
Skip sa' -> Skip (Just sa', sb, Nothing, b, Nothing)
|
||||
Yield a sa' -> Skip (sa', sb, Just a, b) -- step 2/3
|
||||
Skip sa' -> Skip (sa', sb, Nothing, b)
|
||||
Stop -> Stop
|
||||
|
||||
-- step 2
|
||||
step gst (sa, Just sb, a, Nothing, Nothing) = do
|
||||
-- step 2, fetch the second value
|
||||
step gst (sa, sb, a@(Just _), Nothing) = do
|
||||
r <- stepb gst sb
|
||||
return $ case r of
|
||||
Yield b sb' -> Skip (sa, Just sb', a, Just b, Nothing)
|
||||
Skip sb' -> Skip (sa, Just sb', a, Nothing, Nothing)
|
||||
Yield b sb' -> Skip (sa, sb', a, Just b) -- step 3
|
||||
Skip sb' -> Skip (sa, sb', a, Nothing)
|
||||
Stop -> Stop
|
||||
|
||||
-- step 3
|
||||
-- both the values are available compare it
|
||||
step _ (sa, sb, Just a, Just b, Nothing) = do
|
||||
-- step 3, compare the two values
|
||||
step _ (sa, sb, Just a, Just b) = do
|
||||
let res = cmp a b
|
||||
return $ case res of
|
||||
GT -> Skip (sa, sb, Just a, Nothing, Nothing)
|
||||
LT -> Skip (sa, sb, Nothing, Just b, Nothing)
|
||||
EQ -> Yield a (sa, sb, Nothing, Just a, Just b) -- step 4
|
||||
|
||||
-- step 4
|
||||
-- Matching element
|
||||
step gst (Just sa, Just sb, Nothing, Just _, Just b) = do
|
||||
r1 <- stepa gst sa
|
||||
return $ case r1 of
|
||||
Yield a' sa' -> do
|
||||
if a' == b -- match with prev a
|
||||
then Yield a' (Just sa', Just sb, Nothing, Just b, Just b) --step 1
|
||||
else Skip (Just sa', Just sb, Just a', Nothing, Nothing)
|
||||
|
||||
Skip sa' -> Skip (Just sa', Just sb, Nothing, Nothing, Nothing)
|
||||
Stop -> Stop
|
||||
|
||||
step _ (_, _, _, _, _) = return Stop
|
||||
GT -> Skip (sa, sb, Just a, Nothing) -- step 2
|
||||
LT -> Skip (sa, sb, Nothing, Just b) -- step 1
|
||||
EQ -> Yield a (sa, sb, Nothing, Just b) -- step 1
|
||||
|
||||
------------------------------------------------------------------------------
|
||||
-- Combine N Streams - unfoldMany
|
||||
|
@ -101,7 +101,7 @@ extra-source-files:
|
||||
test/Streamly/Test/Data/Array/Prim/Pinned.hs
|
||||
test/Streamly/Test/Data/Array/Foreign.hs
|
||||
test/Streamly/Test/Data/Array/Stream/Foreign.hs
|
||||
test/Streamly/Test/Data/Parser/ParserD.hs
|
||||
test/Streamly/Test/Data/Parser/ParserD.hs
|
||||
test/Streamly/Test/FileSystem/Event.hs
|
||||
test/Streamly/Test/FileSystem/Event/Common.hs
|
||||
test/Streamly/Test/FileSystem/Event/Darwin.hs
|
||||
|
@ -1,8 +1,8 @@
|
||||
module Main (main)
|
||||
where
|
||||
module Main (main) where
|
||||
|
||||
import Data.List (elem, intersect, nub, sort)
|
||||
import Data.Maybe (isNothing)
|
||||
import Streamly.Prelude (SerialT)
|
||||
import Test.QuickCheck
|
||||
( Gen
|
||||
, Property
|
||||
@ -169,8 +169,16 @@ joinLeftMap =
|
||||
let v2 = joinLeftList ls0 ls1
|
||||
assert (v1 == v2)
|
||||
|
||||
intersectBy :: Property
|
||||
intersectBy =
|
||||
intersectBy ::
|
||||
([Int] -> [Int])
|
||||
-> ( (Int -> Int -> a)
|
||||
-> SerialT IO Int
|
||||
-> SerialT IO Int
|
||||
-> SerialT IO Int
|
||||
)
|
||||
-> (Int -> Int -> a)
|
||||
-> Property
|
||||
intersectBy _srt intersectFunc cmp =
|
||||
forAll (listOf (chooseInt (min_value, max_value))) $ \ls0 ->
|
||||
forAll (listOf (chooseInt (min_value, max_value))) $ \ls1 ->
|
||||
monadicIO $ action (sort ls0) (sort ls1)
|
||||
@ -181,33 +189,17 @@ intersectBy =
|
||||
v1 <-
|
||||
run
|
||||
$ S.toList
|
||||
$ Top.intersectBy
|
||||
(==)
|
||||
$ intersectFunc
|
||||
cmp
|
||||
(S.fromList ls0)
|
||||
(S.fromList ls1)
|
||||
let v2 = intersect ls0 ls1
|
||||
assert (v1 == sort v2)
|
||||
|
||||
intersectBySorted :: Property
|
||||
intersectBySorted =
|
||||
forAll (listOf (chooseInt (min_value, max_value))) $ \ls0 ->
|
||||
forAll (listOf (chooseInt (min_value, max_value))) $ \ls1 ->
|
||||
monadicIO $ action (sort ls0) (sort ls1)
|
||||
|
||||
where
|
||||
|
||||
action ls0 ls1 = do
|
||||
v1 <-
|
||||
run
|
||||
$ S.toList
|
||||
$ Top.intersectBySorted
|
||||
compare
|
||||
(S.fromList ls0)
|
||||
(S.fromList ls1)
|
||||
let v2 = intersect ls0 ls1
|
||||
let v2 = ls0 `intersect` ls1
|
||||
assert (v1 == sort v2)
|
||||
|
||||
-------------------------------------------------------------------------------
|
||||
-- Main
|
||||
-------------------------------------------------------------------------------
|
||||
|
||||
moduleName :: String
|
||||
moduleName = "Prelude.Top"
|
||||
|
||||
@ -215,7 +207,6 @@ main :: IO ()
|
||||
main = hspec $ do
|
||||
describe moduleName $ do
|
||||
-- Joins
|
||||
|
||||
prop "joinInner" Main.joinInner
|
||||
prop "joinInnerMap" Main.joinInnerMap
|
||||
-- XXX currently API is broken https://github.com/composewell/streamly/issues/1032
|
||||
@ -224,5 +215,7 @@ main = hspec $ do
|
||||
prop "joinLeft" Main.joinLeft
|
||||
prop "joinLeftMap" Main.joinLeftMap
|
||||
-- intersect
|
||||
prop "intersectBy" Main.intersectBy
|
||||
prop "intersectBySorted" Main.intersectBySorted
|
||||
-- XXX currently API is broken https://github.com/composewell/streamly/issues/1471
|
||||
--prop "intersectBy" (intersectBy id Top.intersectBy (==))
|
||||
prop "intersectBySorted"
|
||||
(intersectBy sort Top.intersectBySorted compare)
|
||||
|
Loading…
Reference in New Issue
Block a user