Use unsafe peek and poke for better performance

This commit is contained in:
Harendra Kumar 2023-08-06 09:32:19 +05:30
parent fa8b0ab0be
commit 06ae33ed93
4 changed files with 149 additions and 54 deletions

View File

@ -1,5 +1,11 @@
{-# LANGUAGE TemplateHaskell #-}
#undef FUSION_CHECK
#ifdef FUSION_CHECK
{-# OPTIONS_GHC -ddump-simpl -ddump-to-file -dsuppress-all #-}
#endif
-- |
-- Module : Streamly.Benchmark.Data.Unbox
-- Copyright : (c) 2023 Composewell
@ -20,7 +26,9 @@ import GHC.Generics (Generic)
import System.Random (randomRIO)
import Streamly.Internal.Data.Unbox
#ifdef USE_TH
import Streamly.Internal.Data.Unbox.TH
#endif
import Gauge
import Streamly.Benchmark.Common
@ -29,30 +37,41 @@ import Streamly.Benchmark.Common
-- Types
-------------------------------------------------------------------------------
data CustomDT1 a b
data CustomDT1
= CDT1C1
| CDT1C2 a
| CDT1C3 a b
| CDT1C2 Int
| CDT1C3 Int Bool
deriving (Generic, Show, Eq)
type CustomDT1_ = CustomDT1 Int Bool
instance (Unbox a, Unbox b) => Unbox (CustomDT1 a b)
data CustomDT2 a b
= CDT2C1
| CDT2C2 a
| CDT2C3 a b
deriving (Show, Eq)
type CustomDT2_ = CustomDT2 Int Bool
$(deriveUnbox ''CustomDT2)
#ifndef USE_TH
instance Unbox CustomDT1
#else
$(deriveUnbox ''CustomDT1)
#endif
-------------------------------------------------------------------------------
-- Helpers
-------------------------------------------------------------------------------
serializeDeserializeTimes :: forall a. (Eq a, Unbox a) => a -> Int -> IO ()
serializeDeserializeTimes val times = do
-- Benchmark helper: create an array with 'newBytes', passing the 'sizeOf' of
-- the value type as its argument, then poke @val@ at byte index 0 of that
-- array @times@ times.
{-# INLINE pokeTimes #-}
pokeTimes :: forall a. Unbox a => a -> Int -> IO ()
pokeTimes val times =
    newBytes (sizeOf (Proxy :: Proxy a))
        >>= \buf -> replicateM_ times (pokeByteIndex 0 buf val)
-- Benchmark helper: create an array with 'newBytes', passing the 'sizeOf' of
-- the value type as its argument, poke @val@ once at byte index 0, and then
-- peek a value of the same type back from byte index 0 @times@ times,
-- discarding each result.
{-# INLINE peekTimes #-}
peekTimes :: forall a. Unbox a => a -> Int -> IO ()
peekTimes val times = do
    buf <- newBytes (sizeOf (Proxy :: Proxy a))
    pokeByteIndex 0 buf val
    replicateM_ times
        (peekByteIndex 0 buf >>= \(_ :: a) -> return ())
{-# INLINE roundtrip #-}
roundtrip :: forall a. (Eq a, Unbox a) => a -> Int -> IO ()
roundtrip val times = do
arr <- newBytes (sizeOf (Proxy :: Proxy a))
replicateM_ times $ do
pokeByteIndex 0 arr val
@ -68,30 +87,30 @@ benchSink name times f = bench name (nfIO (randomRIO (times, times) >>= f))
allBenchmarks :: Int -> [Benchmark]
allBenchmarks times =
[ benchSink
"serializeDeserializeTimes CDT1C1 (Generic)"
times
(serializeDeserializeTimes (CDT1C1 :: CustomDT1_))
, benchSink
"serializeDeserializeTimes CDT2C1 (TH)"
times
(serializeDeserializeTimes (CDT2C1 :: CustomDT2_))
, benchSink
"serializeDeserializeTimes CDT1C2 (Generic)"
times
(serializeDeserializeTimes ((CDT1C2 (5 :: Int)) :: CustomDT1_))
, benchSink
"serializeDeserializeTimes CDT2C2 (TH)"
times
(serializeDeserializeTimes ((CDT2C2 (5 :: Int)) :: CustomDT2_))
, benchSink
"serializeDeserializeTimes CDT1C3 (Generic)"
times
(serializeDeserializeTimes ((CDT1C3 (5 :: Int) True) :: CustomDT1_))
, benchSink
"serializeDeserializeTimes CDT2C3 (TH)"
times
(serializeDeserializeTimes ((CDT2C3 (5 :: Int) True) :: CustomDT2_))
[ bgroup "poke"
[ benchSink "C1" times
(pokeTimes (CDT1C1 :: CustomDT1))
, benchSink "C2" times
(pokeTimes ((CDT1C2 (5 :: Int)) :: CustomDT1))
, benchSink "C3" times
(pokeTimes ((CDT1C3 (5 :: Int) True) :: CustomDT1))
]
, bgroup "peek"
[ benchSink "C1" times
(peekTimes (CDT1C1 :: CustomDT1))
, benchSink "C2" times
(peekTimes ((CDT1C2 (5 :: Int)) :: CustomDT1))
, benchSink "C3" times
(peekTimes ((CDT1C3 (5 :: Int) True) :: CustomDT1))
]
, bgroup "roundtrip"
[ benchSink "C1" times
(roundtrip (CDT1C1 :: CustomDT1))
, benchSink "C2" times
(roundtrip ((CDT1C2 (5 :: Int)) :: CustomDT1))
, benchSink "C3" times
(roundtrip ((CDT1C3 (5 :: Int) True) :: CustomDT1))
]
]
-------------------------------------------------------------------------------
@ -99,4 +118,15 @@ allBenchmarks times =
-------------------------------------------------------------------------------
main :: IO ()
main = runWithCLIOpts defaultStreamSize allBenchmarks
main = do
#ifndef FUSION_CHECK
runWithCLIOpts defaultStreamSize allBenchmarks
#else
-- Enable FUSION_CHECK macro at the beginning of the file
-- Enable one benchmark below, and run the benchmark
-- Check the .dump-simpl output
let value = 100000
-- peekTimes ((CDT1C2 (5 :: Int)) :: CustomDT1) value
roundtrip ((CDT1C2 (5 :: Int)) :: CustomDT1) value
return ()
#endif

View File

@ -362,6 +362,13 @@ benchmark Data.Unbox
hs-source-dirs: .
main-is: Streamly/Benchmark/Data/Unbox.hs
benchmark Data.Unbox.Derive.TH
import: bench-options
type: exitcode-stdio-1.0
hs-source-dirs: .
cpp-options: -DUSE_TH
main-is: Streamly/Benchmark/Data/Unbox.hs
benchmark Data.Unfold
import: bench-options
type: exitcode-stdio-1.0

View File

@ -100,6 +100,7 @@ touch :: MutableByteArray -> IO ()
touch (MutableByteArray contents) =
IO $ \s -> case touch# contents s of s' -> (# s', () #)
{-
-- | Return the size of the array in bytes.
{-# INLINE sizeOfMutableByteArray #-}
sizeOfMutableByteArray :: MutableByteArray -> IO Int
@ -107,6 +108,7 @@ sizeOfMutableByteArray (MutableByteArray arr) =
IO $ \s ->
case getSizeofMutableByteArray# arr s of
(# s1, i #) -> (# s1, I# i #)
-}
--------------------------------------------------------------------------------
-- Creation
@ -282,7 +284,9 @@ unpin arr@(MutableByteArray marr#) =
-- The 'peekByteIndex' read operation converts a Haskell type from its unboxed
-- representation stored in a mutable byte array, while the 'pokeByteIndex'
-- write operation serializes a Haskell data type to its byte representation in
-- the mutable byte array.
-- the mutable byte array. These operations do not check the bounds of the
-- array; the user of the type class is expected to check the bounds before
-- peeking or poking.
--
-- Instances can be derived via Generics or Template Haskell. Note that the
-- data type must be non-recursive.
@ -326,10 +330,12 @@ unpin arr@(MutableByteArray marr#) =
-- instance Unbox Object where
-- sizeOf _ = 16
-- peekByteIndex i arr = do
-- -- Check the array bounds
-- x0 <- peekByteIndex i arr
-- x1 <- peekByteIndex (i + 8) arr
-- return $ Object x0 x1
-- pokeByteIndex i arr (Object x0 x1) = do
-- -- Check the array bounds
-- pokeByteIndex i arr x0
-- pokeByteIndex (i + 8) arr x1
-- :}
@ -587,6 +593,12 @@ readUnsafe = Peeker (Builder step)
-- Peek one value of type a at the current position and return it together
-- with a BoundedPtr whose position has been advanced past that value.
step :: forall a. Unbox a => BoundedPtr -> IO (BoundedPtr, a)
step (BoundedPtr arr pos end) = do
-- Position just past the value being read.
let next = pos + sizeOf (Proxy :: Proxy a)
#ifdef DEBUG
-- The end bound is verified only when DEBUG is defined.
when (next > end)
$ error $ "readUnsafe: reading beyond limit. next = "
++ show next
++ " end = " ++ show end
#endif
r <- peekByteIndex pos arr
return (BoundedPtr arr next end, r)
@ -600,7 +612,10 @@ read = Peeker (Builder step)
step :: forall a. Unbox a => BoundedPtr -> IO (BoundedPtr, a)
step (BoundedPtr arr pos end) = do
let next = pos + sizeOf (Proxy :: Proxy a)
when (next > end) $ error "peekObject reading beyond limit"
when (next > end)
$ error $ "read: reading beyond limit. next = "
++ show next
++ " end = " ++ show end
r <- peekByteIndex pos arr
return (BoundedPtr arr next end, r)
@ -614,10 +629,12 @@ skipByte = Peeker (Builder step)
-- Advance the position of the BoundedPtr by one byte without reading
-- anything from the array; the result value is unit.
step :: BoundedPtr -> IO (BoundedPtr, ())
step (BoundedPtr arr pos end) = do
let next = pos + 1
#ifdef DEBUG
-- The end bound is verified only when DEBUG is defined.
when (next > end)
$ error $ "skipByte: reading beyond limit. next = "
++ show next
++ " end = " ++ show end
#endif
return (BoundedPtr arr next end, ())
{-# INLINE runPeeker #-}
@ -638,6 +655,12 @@ runPeeker (Peeker (Builder f)) ptr = fmap snd (f ptr)
-- Poke the value at the current position of the BoundedPtr and return a
-- BoundedPtr advanced by the 'sizeOf' of the value type. The end bound is
-- verified only when DEBUG is defined; otherwise the caller must make sure
-- the write stays within the array.
pokeBoundedPtrUnsafe :: forall a. Unbox a => a -> BoundedPtr -> IO BoundedPtr
pokeBoundedPtrUnsafe a (BoundedPtr arr pos end) = do
    -- Position just past the value being written.
    let next = pos + sizeOf (Proxy :: Proxy a)
#ifdef DEBUG
    -- This is a write path, so the message says "writing" rather than
    -- "reading" to avoid pointing the reader at the peek functions.
    when (next > end)
        $ error $ "pokeBoundedPtrUnsafe: writing beyond limit. next = "
            ++ show next
            ++ " end = " ++ show end
#endif
    pokeByteIndex pos arr a
    return (BoundedPtr arr next end)
@ -669,7 +692,10 @@ type family ArityCheck (b :: Bool) :: Constraint where
-- Type constraint to restrict the sum type arity so that the constructor tag
-- can fit in a single byte.
type MaxArity256 n = ArityCheck (n <=? 255)
-- Note that Arity starts from 1 and constructor tags start from 0. So if max
-- arity is 256 then max constructor tag would be 255.
-- XXX Use variable length encoding to support more than 256 constructors.
type MaxArity256 n = ArityCheck (n <=? 256)
--------------------------------------------------------------------------------
-- Generic Deriving of Unbox instance
@ -735,8 +761,8 @@ instance (MaxArity256 (SumArity (f :+: g)), SizeOfRepSum f, SizeOfRepSum g) =>
-- The size of a sum type is the max of any of the constructor size.
-- sizeOfRepSum type class operation is used here instead of sizeOfRep so
-- that we add the constructor index byte only for the first time and avoid
-- including it for the subsequent sum constructors.
-- that we account for the constructor index byte only the first time and
-- avoid including it for the subsequent sum constructors.
{-# INLINE sizeOfRep #-}
sizeOfRep _ =
-- One byte for the constructor id and then the constructor value.
@ -753,6 +779,12 @@ instance (MaxArity256 (SumArity (f :+: g)), SizeOfRepSum f, SizeOfRepSum g) =>
-- elements to make the size as 1. Or we can disallow arrays with elements
-- having size 0.
--
-- Some examples:
--
-- data B = B -- one byte
-- data A = A B -- one byte
-- data X = X1 | X2 -- one byte (constructor tag only)
--
{-# INLINE genericSizeOf #-}
genericSizeOf :: forall a. (SizeOfRep (Rep a)) => Proxy a -> Int
genericSizeOf _ =
@ -772,7 +804,7 @@ instance PokeRep f => PokeRep (M1 i c f) where
instance Unbox a => PokeRep (K1 i a) where
{-# INLINE pokeRep #-}
pokeRep a = pokeBoundedPtr (unK1 a)
pokeRep a = pokeBoundedPtrUnsafe (unK1 a)
instance PokeRep V1 where
{-# INLINE pokeRep #-}
@ -795,8 +827,8 @@ class KnownNat n => PokeRepSum (n :: Nat) (f :: Type -> Type) where
instance (KnownNat n, PokeRep a) => PokeRepSum n (C1 c a) where
{-# INLINE pokeRepSum #-}
pokeRepSum _ x ptr = do
pokeBoundedPtr (fromInteger (natVal (Proxy :: Proxy n)) :: Word8) ptr
>>= pokeRep x
let tag = fromInteger (natVal (Proxy :: Proxy n)) :: Word8
pokeBoundedPtrUnsafe tag ptr >>= pokeRep x
instance (PokeRepSum n f, PokeRepSum (n + SumArity f) g)
=> PokeRepSum n (f :+: g) where
@ -827,14 +859,20 @@ genericPokeByteIndex :: (Generic a, PokeRep (Rep a)) =>
MutableByteArray -> Int -> a -> IO ()
-- Hand x to genericPokeObj together with a BoundedPtr over arr whose
-- position is index.
genericPokeByteIndex arr index x = do
-- XXX Should we use unsafe poke?
#ifdef DEBUG
-- With DEBUG defined the end bound is the value returned by
-- sizeOfMutableByteArray for the array.
-- NOTE(review): a definition of sizeOfMutableByteArray appears inside a
-- block comment earlier in this module - confirm the name is still in scope
-- when DEBUG is defined.
end <- sizeOfMutableByteArray arr
genericPokeObj x (BoundedPtr arr index end)
#else
-- Without DEBUG the end bound is undefined, so it must never be forced.
-- NOTE(review): confirm that no poke path inspects the end bound outside
-- DEBUG and that the BoundedPtr end field is lazy.
genericPokeObj x (BoundedPtr arr index undefined)
#endif
--------------------------------------------------------------------------------
-- Generic peek
--------------------------------------------------------------------------------
class PeekRep (f :: Type -> Type) where
-- Like pokeRep, we can use the following signature instead of using Peeker
-- peekRep :: BoundedPtr -> IO (BoundedPtr, f a)
peekRep :: Peeker (f x)
instance PeekRep f => PeekRep (M1 i c f) where
@ -843,7 +881,7 @@ instance PeekRep f => PeekRep (M1 i c f) where
instance Unbox a => PeekRep (K1 i a) where
{-# INLINE peekRep #-}
peekRep = fmap K1 read
peekRep = fmap K1 readUnsafe
instance PeekRep V1 where
{-# INLINE peekRep #-}
@ -865,6 +903,11 @@ class KnownNat n => PeekRepSum (n :: Nat) (f :: Type -> Type) where
instance (KnownNat n, PeekRep a) => PeekRepSum n (C1 c a) where
{-# INLINE peekRepSum #-}
peekRepSum _ _ = peekRep
{-
-- These error checks are expensive, to avoid these
-- we validate the max value of the tag in peekRep.
-- XXX Add tests to cover all cases
peekRepSum _ tag
| tag == curTag = peekRep
| tag > curTag =
@ -876,6 +919,7 @@ instance (KnownNat n, PeekRep a) => PeekRepSum n (C1 c a) where
where
curTag = fromInteger (natVal (Proxy :: Proxy n))
-}
instance (PeekRepSum n f, PeekRepSum (n + SumArity f) g)
=> PeekRepSum n (f :+: g) where
@ -892,12 +936,21 @@ instance (PeekRepSum n f, PeekRepSum (n + SumArity f) g)
-------------------------------------------------------------------------------
instance (MaxArity256 (SumArity (f :+: g)), PeekRepSum 0 (f :+: g))
instance ( MaxArity256 (SumArity (f :+: g))
, KnownNat (SumArity (f :+: g))
, PeekRepSum 0 (f :+: g))
=> PeekRep (f :+: g) where
{-# INLINE peekRep #-}
peekRep = do
tag <- read
peekRepSum (Proxy :: Proxy 0) tag
tag :: Word8 <- readUnsafe
-- XXX test with 256 and more constructors
let arity :: Int =
fromInteger (natVal (Proxy :: Proxy (SumArity (f :+: g))))
when (fromIntegral tag >= arity)
$ error $ "peek: Tag " ++ show tag
++ " is greater than the max tag " ++ show (arity - 1)
++ " for the data type"
peekRepSum (Proxy :: Proxy 0) tag -- DataKinds
{-# INLINE genericPeeker #-}
genericPeeker :: (Generic a, PeekRep (Rep a)) => Peeker a
@ -912,5 +965,9 @@ genericPeekByteIndex :: (Generic a, PeekRep (Rep a)) =>
MutableByteArray -> Int -> IO a
-- Run genericPeekBoundedPtr on a BoundedPtr over arr whose position is
-- index.
genericPeekByteIndex arr index = do
-- XXX Should we use unsafe peek?
#ifdef DEBUG
-- With DEBUG defined the end bound is the value returned by
-- sizeOfMutableByteArray for the array.
-- NOTE(review): a definition of sizeOfMutableByteArray appears inside a
-- block comment earlier in this module - confirm the name is still in scope
-- when DEBUG is defined.
end <- sizeOfMutableByteArray arr
genericPeekBoundedPtr (BoundedPtr arr index end)
#else
-- Without DEBUG the end bound is undefined, so it must never be forced.
-- NOTE(review): confirm that no peek path inspects the end bound outside
-- DEBUG and that the BoundedPtr end field is lazy.
genericPeekBoundedPtr (BoundedPtr arr index undefined)
#endif

View File

@ -187,7 +187,8 @@ targets =
, ("Data.Fold.Window", [ "parser_grp", "infinite_grp" ])
, ("Data.Parser", [ "parser_grp", "infinite_grp" ])
, ("Data.Unbox", ["noBench"])
, ("Data.Unbox", [])
, ("Data.Unbox.Derive.TH", [])
, ("Data.Unfold", ["infinite_grp"])
, ("FileSystem.Handle", [])
, ("Unicode.Stream", [])