Implement partitionBy

This commit is contained in:
Ranjeet Kumar Ranjan 2021-12-27 21:09:53 +05:30 committed by Harendra Kumar
parent 37d3c17e05
commit 21e948a59f
5 changed files with 153 additions and 4 deletions

View File

@ -0,0 +1,96 @@
-- |
-- Module : Streamly.Benchmark.Data.Mut.Array
-- Copyright : (c) 2020 Composewell Technologies
--
-- License : BSD-3-Clause
-- Maintainer : streamly@composewell.com
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
#ifdef __HADDOCK_VERSION__
#undef INSPECTION
#endif
#ifdef INSPECTION
{-# LANGUAGE TemplateHaskell #-}
{-# OPTIONS_GHC -fplugin Test.Inspection.Plugin #-}
#endif
module Main
(
main
) where
import Control.DeepSeq (NFData(..))
import System.Random (randomRIO)
import Prelude hiding ()
import qualified Streamly.Prelude as Stream
import qualified Streamly.Internal.Data.Fold as Fold
import qualified Streamly.Internal.Data.Array.Stream.Mut.Foreign as MArray
import Gauge hiding (env)
import Streamly.Prelude (SerialT, MonadAsync, IsStream)
import Streamly.Benchmark.Common
#ifdef INSPECTION
import Foreign.Storable (Storable)
import Streamly.Internal.Data.Stream.StreamD.Type (Step(..))
import Test.Inspection
#endif
-------------------------------------------------------------------------------
-- Utilities
-------------------------------------------------------------------------------
-- | Generate the consecutive integers starting at @n@ and ending at
-- @n + value@ (inclusive) using a monadic unfold.
{-# INLINE sourceUnfoldrM #-}
sourceUnfoldrM :: (IsStream t, MonadAsync m) => Int -> Int -> t m Int
sourceUnfoldrM value n = Stream.unfoldrM next n

    where

    -- Last integer that is still emitted.
    upper = n + value

    next i
        | i > upper = return Nothing
        | otherwise = return (Just (i, i + 1))
-- | Build a named benchmark: draw an 'Int' from the range @(1,1)@, hand it to
-- @src@ to build the input stream, then run @sink@ on that stream under
-- 'nfIO'.
{-# INLINE benchIO #-}
benchIO
    :: NFData b
    => String -> (Int -> t IO a) -> (t IO a -> IO b) -> Benchmark
benchIO name src sink =
    bench name (nfIO run)

    where

    run = do
        seed <- randomRIO (1,1)
        sink (src seed)
-- | Apply 'MArray.partitionBy' with the predicate @(< bound)@ to every array
-- of the input stream and drain the resulting stream.
{-# INLINE partitionBy #-}
partitionBy :: Int -> SerialT IO (MArray.Array Int) -> IO ()
partitionBy bound arrays =
    Stream.fold Fold.drain (arrays >>= MArray.partitionBy (< bound))
-- | Benchmarks operating on a pre-built list of mutable arrays. The seed
-- passed by 'benchIO' is ignored; the same arrays are streamed every time.
o_n_space_serial_marray ::
    Int -> [MArray.Array Int] -> [Benchmark]
o_n_space_serial_marray bound arrays =
    [ benchIO
        "partitionBy"
        (const (Stream.fromList arrays))
        (partitionBy bound)
    ]
-------------------------------------------------------------------------------
-- Driver
-------------------------------------------------------------------------------
-- | Name of the module under benchmark; 'main' passes it to the prefix
-- function that labels the benchmark group.
moduleName :: String
moduleName = "Data.Mut.Array"
-- | Benchmark entry point: the input arrays are built by @mkArrays@ and then
-- handed, together with the size, to @mkBenchmarks@.
main :: IO ()
main = runWithCLIOptsEnv defaultStreamSize mkArrays mkBenchmarks

    where

    -- Chunk the integers 0..size into arrays of at most @size@ elements.
    mkArrays size =
        Stream.toList (MArray.arraysOf size (sourceUnfoldrM size 0))

    -- NOTE(review): the group is labelled with 'o_1_space_prefix' although the
    -- benchmark list is named o_n_space_*; confirm which one is intended.
    mkBenchmarks arrays size =
        [ bgroup
            (o_1_space_prefix moduleName)
            (o_n_space_serial_marray size arrays)
        ]

View File

@ -483,6 +483,17 @@ benchmark Data.Array.Foreign
if flag(limit-build-mem)
ghc-options: +RTS -M1000M -RTS
-- Benchmark executable whose entry point is Array.hs in the
-- Streamly/Benchmark/Data/Mut source directory. It is marked as not buildable
-- when the compiler is GHCJS.
benchmark Data.Mut.Array
import: bench-options
type: exitcode-stdio-1.0
hs-source-dirs: Streamly/Benchmark/Data/Mut
main-is: Array.hs
if impl(ghcjs)
buildable: False
else
buildable: True
build-depends: exceptions >= 0.8 && < 0.11
-------------------------------------------------------------------------------
-- Array Stream Benchmarks
-------------------------------------------------------------------------------

View File

@ -1206,10 +1206,49 @@ permute = undefined
-- first half retains values where the predicate is 'False' and the second half
-- retains values where the predicate is 'True'.
--
-- /Unimplemented/
-- /Pre-release/
{-# INLINE partitionBy #-}
partitionBy :: (a -> Bool) -> Array a -> m (Array a, Array a)
partitionBy = undefined
partitionBy :: forall m a. (MonadIO m, Storable a)
    => (a -> Bool) -- ^ predicate deciding which half an element belongs to
    -> Array a     -- ^ array to partition; its elements are swapped in place
    -> m (Array a, Array a) -- ^ (elements failing the predicate, elements satisfying it)
partitionBy f arr = go 0 (total - 1)

    where

    total = length arr

    -- Both halves are slices of the input array: indices [0, i) and
    -- [i, total).
    halvesAt i = (getSlice 0 i arr, getSlice i (total - i) arr)

    -- Index of the first element at or after 'low' that satisfies the
    -- predicate, or 'total' if there is none.
    findL low =
        if low == total
        then return low
        else do
            x <- getIndex arr low
            if f x
            then return low
            else findL (low + 1)

    -- Index of the last element at or before 'high' that fails the
    -- predicate; stops at index 0 even if that element satisfies it.
    findR high = do
        x <- getIndex arr high
        if f x && high > 0
        then findR (high - 1)
        else return high

    -- Invariant: every element before 'low' fails the predicate and every
    -- element after 'high' satisfies it.
    go low high
        | low < high = do
            left <- findL low
            right <- findR high
            if left < right
            then do
                -- 'left' satisfies and 'right' fails the predicate, so both
                -- are on the wrong side: exchange them and keep going.
                unsafeSwapIndices arr left right
                go (left + 1) (right - 1)
            else return (halvesAt left)
        | low == high = do
            -- Exactly one element has not been examined yet. It must be
            -- tested explicitly, otherwise an element failing the predicate
            -- would end up in the second half.
            x <- getIndex arr low
            return (halvesAt (if f x then low else low + 1))
        -- The cursors crossed (or the array is empty): everything has been
        -- classified and 'low' is the split point.
        | otherwise = return (halvesAt low)
-- | Shuffle corresponding elements from two arrays using a shuffle function.
-- If the shuffle function returns 'False' then do nothing otherwise swap the

View File

@ -21,6 +21,8 @@ module Streamly.Internal.Data.Array.Stream.Mut.Foreign
, compactLE
, compactEQ
, compactGE
, Array(..)
, partitionBy
)
where
@ -32,7 +34,7 @@ import Control.Monad (when)
import Control.Monad.Catch (MonadThrow)
import Data.Bifunctor (first)
import Foreign.Storable (Storable(..))
import Streamly.Internal.Data.Array.Foreign.Mut.Type (Array(..))
import Streamly.Internal.Data.Array.Foreign.Mut.Type (Array(..), partitionBy)
import Streamly.Internal.Data.Fold.Type (Fold(..))
import Streamly.Internal.Data.Stream.Serial (SerialT(..))
import Streamly.Internal.Data.Stream.IsStream.Type

View File

@ -63,6 +63,7 @@ extra-source-files:
benchmark/Streamly/Benchmark/Data/Array/Prim.hs
benchmark/Streamly/Benchmark/Data/Array/Prim/Pinned.hs
benchmark/Streamly/Benchmark/Data/Array/Stream/Foreign.hs
benchmark/Streamly/Benchmark/Data/Mut/Array.hs
benchmark/Streamly/Benchmark/Data/Parser/*.hs
benchmark/Streamly/Benchmark/Data/Stream/*.hs
benchmark/Streamly/Benchmark/FileSystem/*.hs