Use a macro definition for assertM

So that the error location is reported correctly by the compiler when
the assert is hit.
This commit is contained in:
Harendra Kumar 2022-08-13 15:07:44 +05:30
parent 63ba027ce4
commit cc3bbd76dc
7 changed files with 51 additions and 48 deletions

View File

@ -10,23 +10,11 @@
-- Additional "Control.Exception" utilities. -- Additional "Control.Exception" utilities.
module Streamly.Internal.Control.Exception module Streamly.Internal.Control.Exception
( assertM ( verify
, verify
, verifyM , verifyM
) )
where where
import Control.Exception (assert)
-- Like 'assert' but returns @()@ in an 'Applicative' context so that it can be
-- used as an independent statement in a @do@ block.
--
-- /Pre-release/
--
{-# INLINE assertM #-}
assertM :: Applicative f => Bool -> f ()
assertM predicate = assert predicate (pure ())
-- | Like 'assert' but is not removed by the compiler, it is always present in -- | Like 'assert' but is not removed by the compiler, it is always present in
-- production code. -- production code.
-- --

View File

@ -210,11 +210,11 @@ where
-- When we use a purely lazy Monad like Identity, we need to force ordering of -- When we use a purely lazy Monad like Identity, we need to force ordering of
-- some actions for correctness. -- some actions for correctness.
#include "assert.hs"
#include "inline.hs" #include "inline.hs"
#include "ArrayMacros.h" #include "ArrayMacros.h"
#include "MachDeps.h" #include "MachDeps.h"
import Control.Exception (assert)
import Control.DeepSeq (NFData(..), NFData1(..)) import Control.DeepSeq (NFData(..), NFData1(..))
import Control.Monad (when, void) import Control.Monad (when, void)
import Control.Monad.IO.Class (MonadIO(..)) import Control.Monad.IO.Class (MonadIO(..))
@ -225,7 +225,6 @@ import Data.Semigroup (Semigroup(..))
import Data.Word (Word8) import Data.Word (Word8)
import Foreign.C.Types (CSize(..), CInt(..)) import Foreign.C.Types (CSize(..), CInt(..))
import Foreign.Ptr (plusPtr, minusPtr, nullPtr) import Foreign.Ptr (plusPtr, minusPtr, nullPtr)
import Streamly.Internal.Control.Exception (assertM)
import Streamly.Internal.Data.Unboxed import Streamly.Internal.Data.Unboxed
( ArrayContents(..) ( ArrayContents(..)
, Unboxed , Unboxed
@ -722,7 +721,7 @@ roundDownTo elemSize size = size - (size `mod` elemSize)
{-# NOINLINE reallocAligned #-} {-# NOINLINE reallocAligned #-}
reallocAligned :: Int -> Int -> Int -> Array a -> IO (Array a) reallocAligned :: Int -> Int -> Int -> Array a -> IO (Array a)
reallocAligned elemSize alignSize newCapacityInBytes Array{..} = do reallocAligned elemSize alignSize newCapacityInBytes Array{..} = do
assertM (aEnd <= aBound) assertM(aEnd <= aBound)
-- Allocate new array -- Allocate new array
let newCapMaxInBytes = roundUpLargeArray newCapacityInBytes let newCapMaxInBytes = roundUpLargeArray newCapacityInBytes
@ -776,7 +775,7 @@ reallocWith label capSizer minIncrBytes arr = do
newCapBytes = capSizer oldSizeBytes newCapBytes = capSizer oldSizeBytes
newSizeBytes = oldSizeBytes + minIncrBytes newSizeBytes = oldSizeBytes + minIncrBytes
safeCapBytes = max newCapBytes newSizeBytes safeCapBytes = max newCapBytes newSizeBytes
assertM (safeCapBytes >= newSizeBytes || error (badSize newSizeBytes)) assertM(safeCapBytes >= newSizeBytes || error (badSize newSizeBytes))
realloc safeCapBytes arr realloc safeCapBytes arr
@ -2023,8 +2022,8 @@ fromListRev xs = fromListRevN (Prelude.length xs) xs
{-# INLINE putSliceUnsafe #-} {-# INLINE putSliceUnsafe #-}
putSliceUnsafe :: MonadIO m => Array a -> Int -> Array a -> Int -> Int -> m () putSliceUnsafe :: MonadIO m => Array a -> Int -> Array a -> Int -> Int -> m ()
putSliceUnsafe src srcStartBytes dst dstStartBytes lenBytes = liftIO $ do putSliceUnsafe src srcStartBytes dst dstStartBytes lenBytes = liftIO $ do
assertM (lenBytes <= aBound dst - dstStartBytes) assertM(lenBytes <= aBound dst - dstStartBytes)
assertM (lenBytes <= aEnd src - srcStartBytes) assertM(lenBytes <= aEnd src - srcStartBytes)
let !(I# srcStartBytes#) = srcStartBytes let !(I# srcStartBytes#) = srcStartBytes
!(I# dstStartBytes#) = dstStartBytes !(I# dstStartBytes#) = dstStartBytes
!(I# lenBytes#) = lenBytes !(I# lenBytes#) = lenBytes
@ -2065,7 +2064,7 @@ spliceUnsafe dst src =
let startSrc = arrStart src let startSrc = arrStart src
srcLen = aEnd src - startSrc srcLen = aEnd src - startSrc
endDst = aEnd dst endDst = aEnd dst
assertM (endDst + srcLen <= aBound dst) assertM(endDst + srcLen <= aBound dst)
putSliceUnsafe src startSrc dst endDst srcLen putSliceUnsafe src startSrc dst endDst srcLen
return $ dst {aEnd = endDst + srcLen} return $ dst {aEnd = endDst + srcLen}

View File

@ -196,12 +196,13 @@ module Streamly.Internal.Data.Parser.ParserD
) )
where where
#include "assert.hs"
import Control.Exception (Exception) import Control.Exception (Exception)
import Control.Monad (when) import Control.Monad (when)
import Control.Monad.Catch (MonadCatch, MonadThrow(..)) import Control.Monad.Catch (MonadCatch, MonadThrow(..))
import Data.Bifunctor (first) import Data.Bifunctor (first)
import Fusion.Plugin.Types (Fuse(..)) import Fusion.Plugin.Types (Fuse(..))
import Streamly.Internal.Control.Exception (assertM)
import Streamly.Internal.Data.Fold.Type (Fold(..)) import Streamly.Internal.Data.Fold.Type (Fold(..))
import Streamly.Internal.Data.SVar.Type (defState) import Streamly.Internal.Data.SVar.Type (defState)
import Streamly.Internal.Data.Either.Strict (Either'(..)) import Streamly.Internal.Data.Either.Strict (Either'(..))
@ -1705,17 +1706,17 @@ takeP lim (Parser pstep pinitial pextract) = Parser step initial extract
IError e -> return $ IError e IError e -> return $ IError e
step (Tuple' cnt r) a = do step (Tuple' cnt r) a = do
assertM (cnt < lim) assertM(cnt < lim)
res <- pstep r a res <- pstep r a
let cnt1 = cnt + 1 let cnt1 = cnt + 1
case res of case res of
Partial 0 s -> do Partial 0 s -> do
assertM (cnt1 >= 0) assertM(cnt1 >= 0)
if cnt1 < lim if cnt1 < lim
then return $ Partial 0 $ Tuple' cnt1 s then return $ Partial 0 $ Tuple' cnt1 s
else Done 0 <$> pextract s else Done 0 <$> pextract s
Continue 0 s -> do Continue 0 s -> do
assertM (cnt1 >= 0) assertM(cnt1 >= 0)
if cnt1 < lim if cnt1 < lim
then return $ Continue 0 $ Tuple' cnt1 s then return $ Continue 0 $ Tuple' cnt1 s
-- XXX This should error out? -- XXX This should error out?
@ -1732,11 +1733,11 @@ takeP lim (Parser pstep pinitial pextract) = Parser step initial extract
else Done 0 <$> pextract s else Done 0 <$> pextract s
Partial n s -> do Partial n s -> do
let taken = cnt1 - n let taken = cnt1 - n
assertM (taken >= 0) assertM(taken >= 0)
return $ Partial n $ Tuple' taken s return $ Partial n $ Tuple' taken s
Continue n s -> do Continue n s -> do
let taken = cnt1 - n let taken = cnt1 - n
assertM (taken >= 0) assertM(taken >= 0)
return $ Continue n $ Tuple' taken s return $ Continue n $ Tuple' taken s
Done n b -> return $ Done n b Done n b -> return $ Done n b
Error str -> return $ Error str Error str -> return $ Error str
@ -2131,7 +2132,7 @@ manyTill (Fold fstep finitial fextract)
case r of case r of
Partial n s -> return $ Partial n (ManyTillR 0 fs s) Partial n s -> return $ Partial n (ManyTillR 0 fs s)
Continue n s -> do Continue n s -> do
assertM (cnt + 1 - n >= 0) assertM(cnt + 1 - n >= 0)
return $ Continue n (ManyTillR (cnt + 1 - n) fs s) return $ Continue n (ManyTillR (cnt + 1 - n) fs s)
Done n _ -> do Done n _ -> do
b <- fextract fs b <- fextract fs
@ -2157,7 +2158,7 @@ manyTill (Fold fstep finitial fextract)
case r of case r of
Partial n s -> return $ Partial n (ManyTillL 0 fs s) Partial n s -> return $ Partial n (ManyTillL 0 fs s)
Continue n s -> do Continue n s -> do
assertM (cnt + 1 - n >= 0) assertM(cnt + 1 - n >= 0)
return $ Continue n (ManyTillL (cnt + 1 - n) fs s) return $ Continue n (ManyTillL (cnt + 1 - n) fs s)
Done n b -> do Done n b -> do
fs1 <- fstep fs b fs1 <- fstep fs b

View File

@ -207,6 +207,8 @@ module Streamly.Internal.Data.Parser.ParserD.Type
) )
where where
#include "assert.hs"
import Control.Applicative (Alternative(..), liftA2) import Control.Applicative (Alternative(..), liftA2)
import Control.Exception (Exception(..)) import Control.Exception (Exception(..))
import Control.Monad (MonadPlus(..), (>=>)) import Control.Monad (MonadPlus(..), (>=>))
@ -216,7 +218,6 @@ import Control.Monad.State.Class (MonadState, get, put)
import Control.Monad.Catch (MonadCatch, try, throwM, MonadThrow) import Control.Monad.Catch (MonadCatch, try, throwM, MonadThrow)
import Data.Bifunctor (Bifunctor(..)) import Data.Bifunctor (Bifunctor(..))
import Fusion.Plugin.Types (Fuse(..)) import Fusion.Plugin.Types (Fuse(..))
import Streamly.Internal.Control.Exception (assertM)
import Streamly.Internal.Data.Fold.Type (Fold(..), toList) import Streamly.Internal.Data.Fold.Type (Fold(..), toList)
import Streamly.Internal.Data.Tuple.Strict (Tuple3'(..)) import Streamly.Internal.Data.Tuple.Strict (Tuple3'(..))
@ -505,15 +506,15 @@ parseDToK pstep initial extract leftover (level, count) cont = do
pRes <- pstep r x pRes <- pstep r x
case pRes of case pRes of
Done n b -> do Done n b -> do
assertM (n <= cnt1) assertM(n <= cnt1)
cont (level, cnt1 - n) (K.Success n b) cont (level, cnt1 - n) (K.Success n b)
Error err -> Error err ->
cont (level, cnt1) (K.Failure err) cont (level, cnt1) (K.Failure err)
Partial n pst1 -> do Partial n pst1 -> do
assertM (n <= cnt1) assertM(n <= cnt1)
return $ K.Partial n (parseCont (cnt1 - n) (return pst1)) return $ K.Partial n (parseCont (cnt1 - n) (return pst1))
Continue n pst1 -> do Continue n pst1 -> do
assertM (n <= cnt1) assertM(n <= cnt1)
return $ K.Continue n (parseCont (cnt1 - n) (return pst1)) return $ K.Continue n (parseCont (cnt1 - n) (return pst1))
parseCont cnt acc Nothing = do parseCont cnt acc Nothing = do
pst <- acc pst <- acc
@ -584,7 +585,7 @@ fromParserK parser0 = Parser step initial extract
-- always transitions to only FPKCont. The input remains unconsumed in -- always transitions to only FPKCont. The input remains unconsumed in
-- this case so we use "n + 1". -- this case so we use "n + 1".
step (FPKDone n b) _ = do step (FPKDone n b) _ = do
assertM (n == 0) assertM(n == 0)
return $ Done (n + 1) b return $ Done (n + 1) b
step (FPKCont cont) a = do step (FPKCont cont) a = do
r <- cont (Just a) r <- cont (Just a)
@ -981,7 +982,7 @@ alt (Parser stepL initialL extractL) (Parser stepR initialR extractR) =
case r of case r of
Partial n s -> return $ Partial n (AltParseL 0 s) Partial n s -> return $ Partial n (AltParseL 0 s)
Continue n s -> do Continue n s -> do
assertM (cnt + 1 - n >= 0) assertM(cnt + 1 - n >= 0)
return $ Continue n (AltParseL (cnt + 1 - n) s) return $ Continue n (AltParseL (cnt + 1 - n) s)
Done n b -> return $ Done n b Done n b -> return $ Done n b
Error _ -> do Error _ -> do
@ -1038,13 +1039,13 @@ splitMany (Parser step1 initial1 extract1) (Fold fstep finitial fextract) =
let cnt1 = cnt + 1 let cnt1 = cnt + 1
case r of case r of
Partial n s -> do Partial n s -> do
assertM (cnt1 - n >= 0) assertM(cnt1 - n >= 0)
return $ Continue n (Tuple3' s (cnt1 - n) fs) return $ Continue n (Tuple3' s (cnt1 - n) fs)
Continue n s -> do Continue n s -> do
assertM (cnt1 - n >= 0) assertM(cnt1 - n >= 0)
return $ Continue n (Tuple3' s (cnt1 - n) fs) return $ Continue n (Tuple3' s (cnt1 - n) fs)
Done n b -> do Done n b -> do
assertM (cnt1 - n >= 0) assertM(cnt1 - n >= 0)
fstep fs b >>= handleCollect (Partial n) (Done n) fstep fs b >>= handleCollect (Partial n) (Done n)
Error _ -> do Error _ -> do
xs <- fextract fs xs <- fextract fs
@ -1098,13 +1099,13 @@ splitManyPost (Parser step1 initial1 extract1) (Fold fstep finitial fextract) =
let cnt1 = cnt + 1 let cnt1 = cnt + 1
case r of case r of
Partial n s -> do Partial n s -> do
assertM (cnt1 - n >= 0) assertM(cnt1 - n >= 0)
return $ Continue n (Tuple3' s (cnt1 - n) fs) return $ Continue n (Tuple3' s (cnt1 - n) fs)
Continue n s -> do Continue n s -> do
assertM (cnt1 - n >= 0) assertM(cnt1 - n >= 0)
return $ Continue n (Tuple3' s (cnt1 - n) fs) return $ Continue n (Tuple3' s (cnt1 - n) fs)
Done n b -> do Done n b -> do
assertM (cnt1 - n >= 0) assertM(cnt1 - n >= 0)
fstep fs b >>= handleCollect (Partial n) (Done n) fstep fs b >>= handleCollect (Partial n) (Done n)
Error _ -> do Error _ -> do
xs <- fextract fs xs <- fextract fs
@ -1171,13 +1172,13 @@ splitSome (Parser step1 initial1 extract1) (Fold fstep finitial fextract) =
let cnt1 = cnt + 1 let cnt1 = cnt + 1
case r of case r of
Partial n s -> do Partial n s -> do
assertM (cnt1 - n >= 0) assertM(cnt1 - n >= 0)
return $ Continue n (Tuple3' s (cnt1 - n) (Left fs)) return $ Continue n (Tuple3' s (cnt1 - n) (Left fs))
Continue n s -> do Continue n s -> do
assertM (cnt1 - n >= 0) assertM(cnt1 - n >= 0)
return $ Continue n (Tuple3' s (cnt1 - n) (Left fs)) return $ Continue n (Tuple3' s (cnt1 - n) (Left fs))
Done n b -> do Done n b -> do
assertM (cnt1 - n >= 0) assertM(cnt1 - n >= 0)
fstep fs b >>= handleCollect (Partial n) (Done n) fstep fs b >>= handleCollect (Partial n) (Done n)
Error err -> return $ Error err Error err -> return $ Error err
step (Tuple3' st cnt (Right fs)) a = do step (Tuple3' st cnt (Right fs)) a = do
@ -1185,13 +1186,13 @@ splitSome (Parser step1 initial1 extract1) (Fold fstep finitial fextract) =
let cnt1 = cnt + 1 let cnt1 = cnt + 1
case r of case r of
Partial n s -> do Partial n s -> do
assertM (cnt1 - n >= 0) assertM(cnt1 - n >= 0)
return $ Partial n (Tuple3' s (cnt1 - n) (Right fs)) return $ Partial n (Tuple3' s (cnt1 - n) (Right fs))
Continue n s -> do Continue n s -> do
assertM (cnt1 - n >= 0) assertM(cnt1 - n >= 0)
return $ Continue n (Tuple3' s (cnt1 - n) (Right fs)) return $ Continue n (Tuple3' s (cnt1 - n) (Right fs))
Done n b -> do Done n b -> do
assertM (cnt1 - n >= 0) assertM(cnt1 - n >= 0)
fstep fs b >>= handleCollect (Partial n) (Done n) fstep fs b >>= handleCollect (Partial n) (Done n)
Error _ -> Done cnt1 <$> fextract fs Error _ -> Done cnt1 <$> fextract fs

6
core/src/assert.hs Normal file
View File

@ -0,0 +1,6 @@
-- A convenient macro to assert in a do block. We cannot define this as a
-- Haskell function because then the compiler reports the assert location
-- inside the wrapper function rather than the original location.
import Control.Exception (assert)
#define assertM(p) assert (p) (return ())

6
src/assert.hs Normal file
View File

@ -0,0 +1,6 @@
-- A convenient macro to assert in a do block. We cannot define this as a
-- Haskell function because then the compiler reports the assert location
-- inside the wrapper function rather than the original location.
import Control.Exception (assert)
#define assertM(p) assert (p) (return ())

View File

@ -85,6 +85,7 @@ extra-source-files:
src/Streamly/Internal/Data/Array/ArrayMacros.h src/Streamly/Internal/Data/Array/ArrayMacros.h
src/Streamly/Internal/FileSystem/Event/Darwin.h src/Streamly/Internal/FileSystem/Event/Darwin.h
src/assert.hs
src/config.h.in src/config.h.in
src/inline.hs src/inline.hs
test/Streamly/Test/Data/*.hs test/Streamly/Test/Data/*.hs
@ -124,11 +125,12 @@ extra-source-files:
-- This is temporary as we will soon break this package out -- This is temporary as we will soon break this package out
core/configure core/configure
core/configure.ac core/configure.ac
core/src/assert.hs
core/src/config.h.in
core/src/inline.hs
core/src/Streamly/Internal/Data/Stream/Instances.hs core/src/Streamly/Internal/Data/Stream/Instances.hs
core/src/Streamly/Internal/Data/Array/ArrayMacros.h core/src/Streamly/Internal/Data/Array/ArrayMacros.h
core/src/inline.hs
core/src/Streamly/Internal/Data/Time/Clock/config-clock.h core/src/Streamly/Internal/Data/Time/Clock/config-clock.h
core/src/config.h.in
core/src/Streamly/Internal/BaseCompat.hs core/src/Streamly/Internal/BaseCompat.hs
core/src/Streamly/Internal/Control/Exception.hs core/src/Streamly/Internal/Control/Exception.hs
core/src/Streamly/Internal/Control/Monad.hs core/src/Streamly/Internal/Control/Monad.hs