Add support for byte alignment and bit skip

This commit is contained in:
Sylvain HENRY 2015-03-06 20:40:36 +01:00
parent 4c202bbffb
commit 858e615896
5 changed files with 76 additions and 30 deletions

View File

@ -18,7 +18,6 @@ import Control.Applicative
import Data.Bits import Data.Bits
import Data.Word import Data.Word
import Foreign.Storable import Foreign.Storable
import System.Random
import Data.Traversable (traverse) import Data.Traversable (traverse)
import Data.Foldable (traverse_) import Data.Foldable (traverse_)
@ -262,8 +261,8 @@ prop_fail lbs errMsg0 = forAll (choose (0, 8 * L.length lbs)) $ \len ->
expectedBytesConsumed expectedBytesConsumed
| bits == 0 = bytes | bits == 0 = bytes
| otherwise = bytes + 1 | otherwise = bytes + 1
p = do getByteString (fromIntegral bytes) p = do _ <- getByteString (fromIntegral bytes)
getBits (fromIntegral bits) :: BitGet Word8 _ <- getBits (fromIntegral bits) :: BitGet Word8
fail errMsg0 fail errMsg0
r = runGetIncremental (runBitGet p) `pushChunks` lbs r = runGetIncremental (runBitGet p) `pushChunks` lbs
in case r of in case r of
@ -395,11 +394,6 @@ instance (Arbitrary (W a), Arbitrary (W b), Arbitrary (W c)) => Arbitrary (W (a,
arbitrary = ((W .) .) . (,,) <$> arbitraryW <*> arbitraryW <*> arbitraryW arbitrary = ((W .) .) . (,,) <$> arbitraryW <*> arbitraryW <*> arbitraryW
shrink (W (a,b,c)) = ((W .) .) . (,,) <$> shrinkW a <*> shrinkW b <*> shrinkW c shrink (W (a,b,c)) = ((W .) .) . (,,) <$> shrinkW a <*> shrinkW b <*> shrinkW c
integralRandomR :: (Integral a, RandomGen g) => (a,a) -> g -> (a,g)
integralRandomR (a,b) g = case randomR (fromIntegral a :: Integer,
fromIntegral b :: Integer) g of
(x,g) -> (fromIntegral x, g)
data Primitive data Primitive
= Bool Bool = Bool Bool
| W8 Int Word8 | W8 Int Word8
@ -408,6 +402,7 @@ data Primitive
| W64 Int Word64 | W64 Int Word64
| BS Int B.ByteString | BS Int B.ByteString
| LBS Int L.ByteString | LBS Int L.ByteString
| Skip Int
| IsEmpty | IsEmpty
deriving (Eq, Show) deriving (Eq, Show)
@ -426,6 +421,7 @@ instance Arbitrary Primitive where
, gen W16 , gen W16
, gen W32 , gen W32
, gen W64 , gen W64
, Skip <$> choose (0, 3000)
, do n <- choose (0,10) , do n <- choose (0,10)
cs <- vector n cs <- vector n
return (BS n (B.pack cs)) return (BS n (B.pack cs))
@ -442,6 +438,7 @@ instance Arbitrary Primitive where
W16 _ x -> snk W16 x W16 _ x -> snk W16 x
W32 _ x -> snk W32 x W32 _ x -> snk W32 x
W64 _ x -> snk W64 x W64 _ x -> snk W64 x
Skip x -> Skip <$> shrink x
BS _ bs -> let ws = B.unpack bs in map (\ws' -> BS (length ws') (B.pack ws')) (shrink ws) BS _ bs -> let ws = B.unpack bs in map (\ws' -> BS (length ws') (B.pack ws')) (shrink ws)
LBS _ lbs -> let ws = L.unpack lbs in map (\ws' -> LBS (length ws') (L.pack ws')) (shrink ws) LBS _ lbs -> let ws = L.unpack lbs in map (\ws' -> LBS (length ws') (L.pack ws')) (shrink ws)
IsEmpty -> [] IsEmpty -> []
@ -478,6 +475,7 @@ putPrimitive p =
W16 n x -> putWord16 n x W16 n x -> putWord16 n x
W32 n x -> putWord32 n x W32 n x -> putWord32 n x
W64 n x -> putWord64 n x W64 n x -> putWord64 n x
Skip n -> skipBits n
BS _ bs -> putByteString bs BS _ bs -> putByteString bs
LBS _ lbs -> mapM_ putByteString (L.toChunks lbs) LBS _ lbs -> mapM_ putByteString (L.toChunks lbs)
IsEmpty -> return () IsEmpty -> return ()
@ -490,10 +488,22 @@ getPrimitive p =
W16 n _ -> W16 n <$> getWord16 n W16 n _ -> W16 n <$> getWord16 n
W32 n _ -> W32 n <$> getWord32 n W32 n _ -> W32 n <$> getWord32 n
W64 n _ -> W64 n <$> getWord64 n W64 n _ -> W64 n <$> getWord64 n
Skip n -> skipBits n >> return (Skip n)
BS n _ -> BS n <$> getByteString n BS n _ -> BS n <$> getByteString n
LBS n _ -> LBS n <$> getLazyByteString n LBS n _ -> LBS n <$> getLazyByteString n
IsEmpty -> isEmpty >> return IsEmpty IsEmpty -> isEmpty >> return IsEmpty
getPrimitiveSize :: Primitive -> Int
getPrimitiveSize p = case p of
Bool _ -> 1
W8 n _ -> n
W16 n _ -> n
W32 n _ -> n
W64 n _ -> n
Skip n -> n
BS n _ -> n*8
LBS n _ -> n*8
IsEmpty -> 0
verifyProgram :: Int -> Program -> BitGet Bool verifyProgram :: Int -> Program -> BitGet Bool
verifyProgram totalLength ps0 = go 0 ps0 verifyProgram totalLength ps0 = go 0 ps0
@ -501,13 +511,6 @@ verifyProgram totalLength ps0 = go 0 ps0
go _ [] = return True go _ [] = return True
go pos (p:ps) = go pos (p:ps) =
case p of case p of
Bool x -> check x getBool >> go (pos+1) ps
W8 n x -> check x (getWord8 n) >> go (pos+n) ps
W16 n x -> check x (getWord16 n) >> go (pos+n) ps
W32 n x -> check x (getWord32 n) >> go (pos+n) ps
W64 n x -> check x (getWord64 n) >> go (pos+n) ps
BS n x -> check x (getByteString n) >> go (pos+(8*n)) ps
LBS n x -> check x (getLazyByteString n) >> go (pos+(8*n)) ps
IsEmpty -> do IsEmpty -> do
let expected = pos == totalLength let expected = pos == totalLength
actual <- isEmpty actual <- isEmpty
@ -515,6 +518,7 @@ verifyProgram totalLength ps0 = go 0 ps0
then go pos ps then go pos ps
else error $ "isEmpty returned wrong value, expected " else error $ "isEmpty returned wrong value, expected "
++ show expected ++ " but got " ++ show actual ++ show expected ++ " but got " ++ show actual
_ -> check p (getPrimitive p) >> go (pos + getPrimitiveSize p) ps
check x g = do check x g = do
y <- g y <- g
if x == y if x == y

View File

@ -0,0 +1,22 @@
-----------------------------------------------------------------------------
-- |
-- Module : Data.Binary.Bits.Get
-- Copyright : (c) Lennart Kolmodin 2010-2011
-- (c) Sylvain Henry 2015
-- License : BSD3-style (see LICENSE)
--
-- Maintainer : kolmodin@gmail.com
-- Stability : experimental
-- Portability : portable (should run where the package binary runs)
module Data.Binary.Bits.Alignment
( Alignable(..)
)
where
class Monad m => Alignable m where
-- | Skip the given number of bits
skipBits :: Int -> m ()
-- | Skip bits if necessary to align to the next byte
alignByte :: m ()

View File

@ -98,6 +98,7 @@ import Data.Binary.Get as B ( Get, getLazyByteString, isEmpty )
import Data.Binary.Get.Internal as B ( get, put, ensureN ) import Data.Binary.Get.Internal as B ( get, put, ensureN )
import Data.Binary.Bits.BitOrder import Data.Binary.Bits.BitOrder
import Data.Binary.Bits.Internal import Data.Binary.Bits.Internal
import Data.Binary.Bits.Alignment
import Data.ByteString as B import Data.ByteString as B
import qualified Data.ByteString.Lazy as L import qualified Data.ByteString.Lazy as L
@ -291,6 +292,20 @@ instance BitOrderable BitGet where
(S _ _ bo) <- getState (S _ _ bo) <- getState
return bo return bo
instance Alignable BitGet where
-- | Skip the given number of bits
skipBits n = do
ensureBits n
withState (incS n)
-- | Skip bits if necessary to align to the next byte
alignByte = do
(S _ o _) <- getState
when (o /= 0) $
skipBits (8-o)
-- | Run a 'BitGet' within the Binary packages 'Get' monad. If a byte has -- | Run a 'BitGet' within the Binary packages 'Get' monad. If a byte has
-- been partially consumed it will be discarded once 'runBitGet' is finished. -- been partially consumed it will be discarded once 'runBitGet' is finished.
runBitGet :: BitGet a -> Get a runBitGet :: BitGet a -> Get a
@ -338,20 +353,6 @@ ensureBits n = do
put B.empty put B.empty
return (S (bs`append`bs') o bo, ()) return (S (bs`append`bs') o bo, ())
-- | Skip the given number of bits
skipBits :: Int -> BitGet ()
skipBits n = do
ensureBits n
withState (incS n)
-- | Skip bits if necessary to align to the next byte
alignByte :: BitGet ()
alignByte = do
(S _ o _) <- getState
when (o /= 0) $
skipBits (8-o)
-- | Test whether all input has been consumed, i.e. there are no remaining -- | Test whether all input has been consumed, i.e. there are no remaining
-- undecoded bytes. -- undecoded bytes.
isEmpty :: BitGet Bool isEmpty :: BitGet Bool

View File

@ -41,11 +41,13 @@ import qualified Data.Binary.Put as Put
import Data.Binary.Put ( Put ) import Data.Binary.Put ( Put )
import Data.Binary.Bits.Internal import Data.Binary.Bits.Internal
import Data.Binary.Bits.BitOrder import Data.Binary.Bits.BitOrder
import Data.Binary.Bits.Alignment
import Data.ByteString as BS import Data.ByteString as BS
import Data.ByteString.Unsafe as BS import Data.ByteString.Unsafe as BS
import Control.Applicative import Control.Applicative
import Control.Monad (when)
import Data.Bits import Data.Bits
import Data.Monoid import Data.Monoid
import Data.Word import Data.Word
@ -178,6 +180,9 @@ flushIncomplete s@(S b w o bo)
| o == 0 = s | o == 0 = s
| otherwise = (S (b `mappend` B.singleton w) 0 0 bo) | otherwise = (S (b `mappend` B.singleton w) 0 0 bo)
getOffset :: BitPut Int
getOffset = BitPut $ \s@(S _ _ o _) -> PairS o s
-- | Run the 'BitPut' monad inside 'Put'. -- | Run the 'BitPut' monad inside 'Put'.
runBitPut :: BitPut () -> Put.Put runBitPut :: BitPut () -> Put.Put
runBitPut m = Put.putBuilder b runBitPut m = Put.putBuilder b
@ -208,3 +213,15 @@ instance BitOrderable BitPut where
setBitOrder bo = BitPut $ \(S bu b o _) -> PairS () (S bu b o bo) setBitOrder bo = BitPut $ \(S bu b o _) -> PairS () (S bu b o bo)
getBitOrder = BitPut $ \s@(S _ _ _ bo) -> PairS bo s getBitOrder = BitPut $ \s@(S _ _ _ bo) -> PairS bo s
instance Alignable BitPut where
-- | Skip the given number of bits
skipBits n
| n <= 64 = putWord64 n 0
| otherwise = putWord64 64 0 >> skipBits (n-64)
-- | Skip bits if necessary to align to the next byte
alignByte = do
o <- getOffset
when (o /= 0) $
skipBits (8-o)

View File

@ -25,7 +25,8 @@ library
exposed-modules: Data.Binary.Bits , exposed-modules: Data.Binary.Bits ,
Data.Binary.Bits.Put , Data.Binary.Bits.Put ,
Data.Binary.Bits.Get , Data.Binary.Bits.Get ,
Data.Binary.Bits.BitOrder Data.Binary.Bits.BitOrder ,
Data.Binary.Bits.Alignment
other-modules: Data.Binary.Bits.Internal other-modules: Data.Binary.Bits.Internal
default-language: Haskell98 default-language: Haskell98
@ -36,6 +37,7 @@ test-suite qc
type: exitcode-stdio-1.0 type: exitcode-stdio-1.0
main-is: BitsQC.hs main-is: BitsQC.hs
default-language: Haskell98 default-language: Haskell98
--ghc-options: -O2 -Wall
build-depends: base==4.*, binary >= 0.6.0.0, bytestring, build-depends: base==4.*, binary >= 0.6.0.0, bytestring,
QuickCheck>=2, random, QuickCheck>=2, random,