mirror of
https://github.com/ilyakooo0/streamly.git
synced 2024-10-26 09:59:48 +03:00
Pass the array size to deserialize for bound check
This commit is contained in:
parent
ad84b443b0
commit
7bc450d725
@ -304,14 +304,14 @@ pokeTimes val times = do
|
||||
loopWith times poke arr val
|
||||
|
||||
{-# INLINE peek #-}
|
||||
peek :: forall a. (Eq a, SERIALIZE_CLASS a) => a -> MutableByteArray -> IO ()
|
||||
peek val arr = do
|
||||
peek :: forall a. (Eq a, SERIALIZE_CLASS a) => (a, Int) -> MutableByteArray -> IO ()
|
||||
#ifdef USE_UNBOX
|
||||
(val1 :: a)
|
||||
peek (val, _) arr = do
|
||||
(val1 :: a) <- DESERIALIZE_OP 0 arr
|
||||
#else
|
||||
(_, val1 :: a)
|
||||
peek (val, n) arr = do
|
||||
(_, val1 :: a) <- DESERIALIZE_OP 0 arr n
|
||||
#endif
|
||||
<- DESERIALIZE_OP 0 arr
|
||||
-- Ensure that we are actually constructing the type and using it. This
|
||||
-- is important, otherwise the structure is created and discarded, the
|
||||
-- cost of creation of the structure is not accounted. Otherwise we may
|
||||
@ -328,7 +328,7 @@ peekTimes :: (Eq a, SERIALIZE_CLASS a) => Int -> a -> Int -> IO ()
|
||||
peekTimes n val times = do
|
||||
arr <- newBytes n
|
||||
_ <- SERIALIZE_OP 0 arr val
|
||||
loopWith times peek val arr
|
||||
loopWith times peek (val, n) arr
|
||||
|
||||
{-# INLINE trip #-}
|
||||
trip :: forall a. (Eq a, SERIALIZE_CLASS a) => a -> IO ()
|
||||
@ -337,11 +337,10 @@ trip val = do
|
||||
arr <- newBytes n
|
||||
_ <- SERIALIZE_OP 0 arr val
|
||||
#ifdef USE_UNBOX
|
||||
val1
|
||||
val1 <- DESERIALIZE_OP 0 arr
|
||||
#else
|
||||
(_, val1)
|
||||
(_, val1) <- DESERIALIZE_OP 0 arr n
|
||||
#endif
|
||||
<- DESERIALIZE_OP 0 arr
|
||||
-- Do not remove this, see the comments in peek.
|
||||
if (val1 /= val)
|
||||
then error "roundtrip: no match"
|
||||
|
@ -31,7 +31,6 @@ import Streamly.Internal.Data.Unbox
|
||||
( MutableByteArray(..)
|
||||
, PinnedState(..)
|
||||
, Unbox
|
||||
, sizeOfMutableByteArray
|
||||
)
|
||||
import Streamly.Internal.Data.Array.Type (Array(..))
|
||||
import Streamly.Internal.System.IO (unsafeInlineIO)
|
||||
@ -83,9 +82,9 @@ newtype Size a = Size (Int -> a -> Int) -- a left fold or Sum monoid
|
||||
-- (Size f, Size g) ->
|
||||
-- Size $ \acc obj ->
|
||||
-- acc + f 0 (_obj1 obj) + g 0 (_obj2 obj)
|
||||
-- deserialize i arr = do
|
||||
-- (i1, x0) <- deserialize i arr
|
||||
-- (i2, x1) <- deserialize i1 arr
|
||||
-- deserialize i arr len = do
|
||||
-- (i1, x0) <- deserialize i arr len
|
||||
-- (i2, x1) <- deserialize i1 arr len
|
||||
-- pure (i2, Object x0 x1)
|
||||
-- serialize i arr (Object x0 x1) = do
|
||||
-- i1 <- serialize i arr x0
|
||||
@ -102,12 +101,16 @@ class Serialize a where
|
||||
-- offset but that may require traversing the Haskell structure again to get
|
||||
-- the size. Therefore, this is a performance optimization.
|
||||
|
||||
-- | Deserialize a value from the given byte-index in the array. Returns a
|
||||
-- tuple of the next byte-index and the deserialized value.
|
||||
deserialize :: Int -> MutableByteArray -> IO (Int, a)
|
||||
-- | @deserialize offset array arrayLen@ deserializes a value from the
|
||||
-- given byte-index in the array. Returns a tuple of the next byte-index
|
||||
-- and the deserialized value.
|
||||
deserialize :: Int -> MutableByteArray -> Int -> IO (Int, a)
|
||||
|
||||
-- | Write the serialized representation of the value in the array at the
|
||||
-- given byte-index. Returns the next byte-index.
|
||||
-- given byte-index. Returns the next byte-index. This is an unsafe
|
||||
-- operation; the programmer must ensure that the array has enough space
|
||||
-- available to serialize the value, as determined by the
|
||||
-- @size@ operation.
|
||||
serialize :: Int -> MutableByteArray -> a -> IO Int
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
@ -132,12 +135,29 @@ checkBounds _label _off _arr = do
|
||||
else return ()
|
||||
#endif
|
||||
|
||||
-- Note: Instead of passing around the size parameter, we can use
|
||||
-- (sizeOfMutableByteArray arr) for checking the array bound, but that turns
|
||||
-- out to be more expensive.
|
||||
--
|
||||
-- Another way to optimize this is to avoid the check for fixed size
|
||||
-- structures. For fixed size structures we can do a check at the top level and
|
||||
-- then use checkless deserialization using the Unbox type class. That will
|
||||
-- require ConstSize and VarSize constructors in size. The programmer can
|
||||
-- bundle all const size fields in a newtype to make serialization faster. This
|
||||
-- can speed up the computation of size when serializing and checking size when
|
||||
-- deserializing.
|
||||
--
|
||||
-- For variable size non-recursive structures a separate size validation method
|
||||
-- could be used to validate the size before deserializing. "validate" can also
|
||||
-- be used to collapse multiple chunks of arrays coming from network into a
|
||||
-- single array for deserializing. But that can also be done by framing the
|
||||
-- serialized value with a size header.
|
||||
--
|
||||
{-# INLINE deserializeUnsafe #-}
|
||||
deserializeUnsafe :: forall a. Unbox a => Int -> MutableByteArray -> IO (Int, a)
|
||||
deserializeUnsafe off arr =
|
||||
deserializeUnsafe :: forall a. Unbox a => Int -> MutableByteArray -> Int -> IO (Int, a)
|
||||
deserializeUnsafe off arr sz =
|
||||
let next = off + Unbox.sizeOf (Proxy :: Proxy a)
|
||||
in do
|
||||
sz <- sizeOfMutableByteArray arr
|
||||
-- Keep likely path in the straight branch.
|
||||
if (next <= sz)
|
||||
then Unbox.peekByteIndex off arr >>= \val -> pure (next, val)
|
||||
@ -162,7 +182,7 @@ instance Serialize _type where \
|
||||
; {-# INLINE size #-} \
|
||||
; size = Size (\acc _ -> acc + Unbox.sizeOf (Proxy :: Proxy _type)) \
|
||||
; {-# INLINE deserialize #-} \
|
||||
; deserialize off arr = deserializeUnsafe off arr :: IO (Int, _type) \
|
||||
; deserialize off arr end = deserializeUnsafe off arr end :: IO (Int, _type) \
|
||||
; {-# INLINE serialize #-} \
|
||||
; serialize = \
|
||||
serializeUnsafe :: Int -> MutableByteArray -> _type -> IO Int
|
||||
@ -193,19 +213,19 @@ instance forall a. Serialize a => Serialize [a] where
|
||||
Size f -> foldl' f (acc + (Unbox.sizeOf (Proxy :: Proxy Int))) xs
|
||||
|
||||
{-# INLINE deserialize #-}
|
||||
deserialize off arr = do
|
||||
deserialize off arr sz = do
|
||||
len <- Unbox.peekByteIndex off arr :: IO Int
|
||||
let off1 = off + Unbox.sizeOf (Proxy :: Proxy Int)
|
||||
let
|
||||
peekList f o i | i >= 3 = do
|
||||
-- Unfold the loop three times
|
||||
(o1, x1) <- deserialize o arr
|
||||
(o2, x2) <- deserialize o1 arr
|
||||
(o3, x3) <- deserialize o2 arr
|
||||
(o1, x1) <- deserialize o arr sz
|
||||
(o2, x2) <- deserialize o1 arr sz
|
||||
(o3, x3) <- deserialize o2 arr sz
|
||||
peekList (f . (\xs -> x1:x2:x3:xs)) o3 (i - 3)
|
||||
peekList f o 0 = pure (o, f [])
|
||||
peekList f o i = do
|
||||
(o1, x) <- deserialize o arr
|
||||
(o1, x) <- deserialize o arr sz
|
||||
peekList (f . (x:)) o1 (i - 1)
|
||||
peekList id off1 len
|
||||
|
||||
@ -250,8 +270,8 @@ pinnedEncode = encodeAs Pinned
|
||||
decode :: Serialize a => Array Word8 -> a
|
||||
decode arr@(Array {..}) = unsafeInlineIO $ do
|
||||
let lenArr = Array.length arr
|
||||
(off1, lenEncoding :: Int64) <- deserialize 0 arrContents
|
||||
(off2, val) <- deserialize off1 arrContents
|
||||
(off1, lenEncoding :: Int64) <- deserialize 0 arrContents lenArr
|
||||
(off2, val) <- deserialize off1 arrContents lenArr
|
||||
assertM(fromIntegral lenEncoding + off1 == off2)
|
||||
assertM(lenArr == off2)
|
||||
pure val
|
||||
|
@ -53,6 +53,9 @@ _tag = mkName "tag"
|
||||
_initialOffset :: Name
|
||||
_initialOffset = mkName "initialOffset"
|
||||
|
||||
_endOffset :: Name
|
||||
_endOffset = mkName "endOffset"
|
||||
|
||||
_val :: Name
|
||||
_val = mkName "val"
|
||||
|
||||
@ -194,7 +197,7 @@ mkDeserializeExprOne (DataCon cname _ _ fields) =
|
||||
makeBind i =
|
||||
bindS
|
||||
(tupP [varP (makeI (i + 1)), varP (makeA i)])
|
||||
[|deserialize $(varE (makeI i)) $(varE _arr)|]
|
||||
[|deserialize $(varE (makeI i)) $(varE _arr) $(varE _endOffset)|]
|
||||
|
||||
|
||||
mkDeserializeExpr :: Type -> [DataCon] -> Q Exp
|
||||
@ -217,7 +220,7 @@ mkDeserializeExpr headTy cons =
|
||||
doE
|
||||
[ bindS
|
||||
(tupP [varP (mkName "i0"), varP _tag])
|
||||
[|deserialize $(varE _initialOffset) $(varE _arr)|]
|
||||
[|deserialize $(varE _initialOffset) $(varE _arr) $(varE _endOffset)|]
|
||||
, noBindS
|
||||
(caseE
|
||||
(sigE (varE _tag) (conT tagType))
|
||||
@ -351,8 +354,8 @@ deriveSerializeInternal preds headTy cons = do
|
||||
'deserialize
|
||||
[ Clause
|
||||
(if isUnitType cons
|
||||
then [VarP _initialOffset, WildP]
|
||||
else [VarP _initialOffset, VarP _arr])
|
||||
then [VarP _initialOffset, WildP, WildP]
|
||||
else [VarP _initialOffset, VarP _arr, VarP _endOffset])
|
||||
(NormalB peekMethod)
|
||||
[]
|
||||
]
|
||||
|
@ -68,7 +68,7 @@ roundtrip val = do
|
||||
arr <- newBytes sz
|
||||
|
||||
off1 <- Serialize.serialize 0 arr val
|
||||
(off2, val2) <- Serialize.deserialize 0 arr
|
||||
(off2, val2) <- Serialize.deserialize 0 arr sz
|
||||
val2 `shouldBe` val
|
||||
off2 `shouldBe` off1
|
||||
|
||||
|
@ -64,13 +64,13 @@ import Test.Hspec as H
|
||||
|
||||
#define MODULE_NAME "Data.Serialize.Deriving.TH"
|
||||
#define DERIVE_UNBOX(typ) $(deriveSerialize ''typ)
|
||||
#define PEEK(i, arr) fmap snd (deserialize i arr)
|
||||
#define PEEK(i, arr, sz) fmap snd (deserialize i arr sz)
|
||||
#define POKE(i, arr, val) void (serialize i arr val)
|
||||
#define TYPE_CLASS Serialize
|
||||
|
||||
#else
|
||||
|
||||
#define PEEK(i, arr) peekByteIndex i arr
|
||||
#define PEEK(i, arr, sz) peekByteIndex i arr
|
||||
#define POKE(i, arr, val) pokeByteIndex i arr val
|
||||
#define TYPE_CLASS Unbox
|
||||
|
||||
@ -189,14 +189,15 @@ testSerialization ::
|
||||
=> a
|
||||
-> IO ()
|
||||
testSerialization val = do
|
||||
arr <- newBytes
|
||||
let len =
|
||||
#ifdef USE_SERIALIZE
|
||||
(variableSizeOf val)
|
||||
#else
|
||||
(sizeOf (Proxy :: Proxy a))
|
||||
#endif
|
||||
arr <- newBytes len
|
||||
POKE(0, arr, val)
|
||||
PEEK(0, arr) `shouldReturn` val
|
||||
PEEK(0, arr, len) `shouldReturn` val
|
||||
|
||||
testGenericConsistency ::
|
||||
forall a.
|
||||
@ -216,12 +217,13 @@ testGenericConsistency ::
|
||||
testGenericConsistency val = do
|
||||
|
||||
-- Test the generic sizeOf
|
||||
let len =
|
||||
#ifdef USE_SERIALIZE
|
||||
variableSizeOf val
|
||||
variableSizeOf val
|
||||
#else
|
||||
sizeOf (Proxy :: Proxy a)
|
||||
sizeOf (Proxy :: Proxy a)
|
||||
#endif
|
||||
`shouldBe` genericSizeOf (Proxy :: Proxy a)
|
||||
len `shouldBe` genericSizeOf (Proxy :: Proxy a)
|
||||
|
||||
-- Test the serialization and deserialization
|
||||
arr <- newBytes (sizeOf (Proxy :: Proxy a))
|
||||
@ -230,7 +232,7 @@ testGenericConsistency val = do
|
||||
genericPeekByteIndex arr 0 `shouldReturn` val
|
||||
|
||||
genericPokeByteIndex arr 0 val
|
||||
PEEK(0, arr) `shouldReturn` val
|
||||
PEEK(0, arr, len) `shouldReturn` val
|
||||
|
||||
|
||||
#ifndef USE_SERIALIZE
|
||||
|
Loading…
Reference in New Issue
Block a user