Pass the array size to deserialize for bound check

This commit is contained in:
Harendra Kumar 2023-08-21 11:33:59 +05:30
parent ad84b443b0
commit 7bc450d725
5 changed files with 65 additions and 41 deletions

View File

@ -304,14 +304,14 @@ pokeTimes val times = do
loopWith times poke arr val
{-# INLINE peek #-}
peek :: forall a. (Eq a, SERIALIZE_CLASS a) => a -> MutableByteArray -> IO ()
peek val arr = do
peek :: forall a. (Eq a, SERIALIZE_CLASS a) => (a, Int) -> MutableByteArray -> IO ()
#ifdef USE_UNBOX
(val1 :: a)
peek (val, _) arr = do
(val1 :: a) <- DESERIALIZE_OP 0 arr
#else
(_, val1 :: a)
peek (val, n) arr = do
(_, val1 :: a) <- DESERIALIZE_OP 0 arr n
#endif
<- DESERIALIZE_OP 0 arr
-- Ensure that we are actually constructing the type and using it. This
-- is important, otherwise the structure is created and discarded, the
-- cost of creation of the structure is not accounted. Otherwise we may
@ -328,7 +328,7 @@ peekTimes :: (Eq a, SERIALIZE_CLASS a) => Int -> a -> Int -> IO ()
peekTimes n val times = do
arr <- newBytes n
_ <- SERIALIZE_OP 0 arr val
loopWith times peek val arr
loopWith times peek (val, n) arr
{-# INLINE trip #-}
trip :: forall a. (Eq a, SERIALIZE_CLASS a) => a -> IO ()
@ -337,11 +337,10 @@ trip val = do
arr <- newBytes n
_ <- SERIALIZE_OP 0 arr val
#ifdef USE_UNBOX
val1
val1 <- DESERIALIZE_OP 0 arr
#else
(_, val1)
(_, val1) <- DESERIALIZE_OP 0 arr n
#endif
<- DESERIALIZE_OP 0 arr
-- Do not remove this, see the comments in peek.
if (val1 /= val)
then error "roundtrip: no match"

View File

@ -31,7 +31,6 @@ import Streamly.Internal.Data.Unbox
( MutableByteArray(..)
, PinnedState(..)
, Unbox
, sizeOfMutableByteArray
)
import Streamly.Internal.Data.Array.Type (Array(..))
import Streamly.Internal.System.IO (unsafeInlineIO)
@ -83,9 +82,9 @@ newtype Size a = Size (Int -> a -> Int) -- a left fold or Sum monoid
-- (Size f, Size g) ->
-- Size $ \acc obj ->
-- acc + f 0 (_obj1 obj) + g 0 (_obj2 obj)
-- deserialize i arr = do
-- (i1, x0) <- deserialize i arr
-- (i2, x1) <- deserialize i1 arr
-- deserialize i arr len = do
-- (i1, x0) <- deserialize i arr len
-- (i2, x1) <- deserialize i1 arr len
-- pure (i2, Object x0 x1)
-- serialize i arr (Object x0 x1) = do
-- i1 <- serialize i arr x0
@ -102,12 +101,16 @@ class Serialize a where
-- offset but that may require traversing the Haskell structure again to get
-- the size. Therefore, this is a performance optimization.
-- | Deserialize a value from the given byte-index in the array. Returns a
-- tuple of the next byte-index and the deserialized value.
deserialize :: Int -> MutableByteArray -> IO (Int, a)
-- | @deserialize offset array arrayLen@ deserializes a value from the
-- given byte-index in the array. Returns a tuple of the next byte-index
-- and the deserialized value.
deserialize :: Int -> MutableByteArray -> Int -> IO (Int, a)
-- | Write the serialized representation of the value in the array at the
-- given byte-index. Returns the next byte-index.
-- given byte-index. Returns the next byte-index. This is an unsafe
-- operation, the programmer must ensure that the array has enough space
-- available in the array to serialize the value as determined by the
-- @size@ operation.
serialize :: Int -> MutableByteArray -> a -> IO Int
--------------------------------------------------------------------------------
@ -132,12 +135,29 @@ checkBounds _label _off _arr = do
else return ()
#endif
-- Note: Instead of passing around the size parameter, we can use
-- (sizeOfMutableByteArray arr) for checking the array bound, but that turns
-- out to be more expensive.
--
-- Another way to optimize this is to avoid the check for fixed size
-- structures. For fixed size structures we can do a check at the top level and
-- then use checkless deserialization using the Unbox type class. That will
-- require ConstSize and VarSize constructors in size. The programmer can
-- bundle all const size fields in a newtype to make serialization faster. This
-- can speed up the computation of size when serializing and checking size when
-- deserialing.
--
-- For variable size non-recursive structures a separate size validation method
-- could be used to validate the size before deserializing. "validate" can also
-- be used to collpase multiple chunks of arrays coming from network into a
-- single array for deserializing. But that can also be done by framing the
-- serialized value with a size header.
--
{-# INLINE deserializeUnsafe #-}
deserializeUnsafe :: forall a. Unbox a => Int -> MutableByteArray -> IO (Int, a)
deserializeUnsafe off arr =
deserializeUnsafe :: forall a. Unbox a => Int -> MutableByteArray -> Int -> IO (Int, a)
deserializeUnsafe off arr sz =
let next = off + Unbox.sizeOf (Proxy :: Proxy a)
in do
sz <- sizeOfMutableByteArray arr
-- Keep likely path in the straight branch.
if (next <= sz)
then Unbox.peekByteIndex off arr >>= \val -> pure (next, val)
@ -162,7 +182,7 @@ instance Serialize _type where \
; {-# INLINE size #-} \
; size = Size (\acc _ -> acc + Unbox.sizeOf (Proxy :: Proxy _type)) \
; {-# INLINE deserialize #-} \
; deserialize off arr = deserializeUnsafe off arr :: IO (Int, _type) \
; deserialize off arr end = deserializeUnsafe off arr end :: IO (Int, _type) \
; {-# INLINE serialize #-} \
; serialize = \
serializeUnsafe :: Int -> MutableByteArray -> _type -> IO Int
@ -193,19 +213,19 @@ instance forall a. Serialize a => Serialize [a] where
Size f -> foldl' f (acc + (Unbox.sizeOf (Proxy :: Proxy Int))) xs
{-# INLINE deserialize #-}
deserialize off arr = do
deserialize off arr sz = do
len <- Unbox.peekByteIndex off arr :: IO Int
let off1 = off + Unbox.sizeOf (Proxy :: Proxy Int)
let
peekList f o i | i >= 3 = do
-- Unfold the loop three times
(o1, x1) <- deserialize o arr
(o2, x2) <- deserialize o1 arr
(o3, x3) <- deserialize o2 arr
(o1, x1) <- deserialize o arr sz
(o2, x2) <- deserialize o1 arr sz
(o3, x3) <- deserialize o2 arr sz
peekList (f . (\xs -> x1:x2:x3:xs)) o3 (i - 3)
peekList f o 0 = pure (o, f [])
peekList f o i = do
(o1, x) <- deserialize o arr
(o1, x) <- deserialize o arr sz
peekList (f . (x:)) o1 (i - 1)
peekList id off1 len
@ -250,8 +270,8 @@ pinnedEncode = encodeAs Pinned
decode :: Serialize a => Array Word8 -> a
decode arr@(Array {..}) = unsafeInlineIO $ do
let lenArr = Array.length arr
(off1, lenEncoding :: Int64) <- deserialize 0 arrContents
(off2, val) <- deserialize off1 arrContents
(off1, lenEncoding :: Int64) <- deserialize 0 arrContents lenArr
(off2, val) <- deserialize off1 arrContents lenArr
assertM(fromIntegral lenEncoding + off1 == off2)
assertM(lenArr == off2)
pure val

View File

@ -53,6 +53,9 @@ _tag = mkName "tag"
_initialOffset :: Name
_initialOffset = mkName "initialOffset"
_endOffset :: Name
_endOffset = mkName "endOffset"
_val :: Name
_val = mkName "val"
@ -194,7 +197,7 @@ mkDeserializeExprOne (DataCon cname _ _ fields) =
makeBind i =
bindS
(tupP [varP (makeI (i + 1)), varP (makeA i)])
[|deserialize $(varE (makeI i)) $(varE _arr)|]
[|deserialize $(varE (makeI i)) $(varE _arr) $(varE _endOffset)|]
mkDeserializeExpr :: Type -> [DataCon] -> Q Exp
@ -217,7 +220,7 @@ mkDeserializeExpr headTy cons =
doE
[ bindS
(tupP [varP (mkName "i0"), varP _tag])
[|deserialize $(varE _initialOffset) $(varE _arr)|]
[|deserialize $(varE _initialOffset) $(varE _arr) $(varE _endOffset)|]
, noBindS
(caseE
(sigE (varE _tag) (conT tagType))
@ -351,8 +354,8 @@ deriveSerializeInternal preds headTy cons = do
'deserialize
[ Clause
(if isUnitType cons
then [VarP _initialOffset, WildP]
else [VarP _initialOffset, VarP _arr])
then [VarP _initialOffset, WildP, WildP]
else [VarP _initialOffset, VarP _arr, VarP _endOffset])
(NormalB peekMethod)
[]
]

View File

@ -68,7 +68,7 @@ roundtrip val = do
arr <- newBytes sz
off1 <- Serialize.serialize 0 arr val
(off2, val2) <- Serialize.deserialize 0 arr
(off2, val2) <- Serialize.deserialize 0 arr sz
val2 `shouldBe` val
off2 `shouldBe` off1

View File

@ -64,13 +64,13 @@ import Test.Hspec as H
#define MODULE_NAME "Data.Serialize.Deriving.TH"
#define DERIVE_UNBOX(typ) $(deriveSerialize ''typ)
#define PEEK(i, arr) fmap snd (deserialize i arr)
#define PEEK(i, arr, sz) fmap snd (deserialize i arr sz)
#define POKE(i, arr, val) void (serialize i arr val)
#define TYPE_CLASS Serialize
#else
#define PEEK(i, arr) peekByteIndex i arr
#define PEEK(i, arr, sz) peekByteIndex i arr
#define POKE(i, arr, val) pokeByteIndex i arr val
#define TYPE_CLASS Unbox
@ -189,14 +189,15 @@ testSerialization ::
=> a
-> IO ()
testSerialization val = do
arr <- newBytes
let len =
#ifdef USE_SERIALIZE
(variableSizeOf val)
#else
(sizeOf (Proxy :: Proxy a))
#endif
arr <- newBytes len
POKE(0, arr, val)
PEEK(0, arr) `shouldReturn` val
PEEK(0, arr, len) `shouldReturn` val
testGenericConsistency ::
forall a.
@ -216,12 +217,13 @@ testGenericConsistency ::
testGenericConsistency val = do
-- Test the generic sizeOf
let len =
#ifdef USE_SERIALIZE
variableSizeOf val
variableSizeOf val
#else
sizeOf (Proxy :: Proxy a)
sizeOf (Proxy :: Proxy a)
#endif
`shouldBe` genericSizeOf (Proxy :: Proxy a)
len `shouldBe` genericSizeOf (Proxy :: Proxy a)
-- Test the serialization and deserialization
arr <- newBytes (sizeOf (Proxy :: Proxy a))
@ -230,7 +232,7 @@ testGenericConsistency val = do
genericPeekByteIndex arr 0 `shouldReturn` val
genericPokeByteIndex arr 0 val
PEEK(0, arr) `shouldReturn` val
PEEK(0, arr, len) `shouldReturn` val
#ifndef USE_SERIALIZE