diff --git a/src/Analysis/Abstract/BadValues.hs b/src/Analysis/Abstract/BadValues.hs index 68eeb6b97..94b4dcd7d 100644 --- a/src/Analysis/Abstract/BadValues.hs +++ b/src/Analysis/Abstract/BadValues.hs @@ -38,6 +38,7 @@ instance ( Effectful m BitwiseError{} -> hole >>= yield Bitwise2Error{} -> hole >>= yield KeyValueError{} -> hole >>= \x -> yield (x, x) + ArithmeticError{} -> hole >>= yield ) analyzeModule = liftAnalyze analyzeModule diff --git a/src/Control/Abstract/Value.hs b/src/Control/Abstract/Value.hs index fa7fa1908..2c02315e4 100644 --- a/src/Control/Abstract/Value.hs +++ b/src/Control/Abstract/Value.hs @@ -204,6 +204,8 @@ data ValueError location value resume where BitwiseError :: value -> ValueError location value value Bitwise2Error :: value -> value -> ValueError location value value KeyValueError :: value -> ValueError location value (value, value) + -- Indicates that we encountered an arithmetic exception inside Haskell-native number crunching. + ArithmeticError :: ArithException -> ValueError location value value instance Eq value => Eq1 (ValueError location value) where liftEq _ (StringError a) (StringError b) = a == b diff --git a/src/Data/Abstract/Value.hs b/src/Data/Abstract/Value.hs index 256af31c5..d47faef94 100644 --- a/src/Data/Abstract/Value.hs +++ b/src/Data/Abstract/Value.hs @@ -7,6 +7,7 @@ import qualified Data.Abstract.Environment as Env import Data.Abstract.Evaluatable import qualified Data.Abstract.Number as Number import Data.Scientific (Scientific) +import Data.Scientific.Exts import qualified Data.Set as Set import Prologue hiding (TypeError) import Prelude hiding (Float, Integer, String, Rational) @@ -265,22 +266,25 @@ instance (Monad (m effects), MonadEvaluatable location term (Value location) eff | otherwise = throwValueError (NumericError arg) liftNumeric2 f left right - | Just (Integer i, Integer j) <- prjPair pair = f i j & specialize - | Just (Integer i, Rational j) <- prjPair pair = f i j & specialize - | Just (Integer i, Float j) <- prjPair pair = f i j & specialize - | Just (Rational i, Integer j) <- prjPair pair = f i j & specialize - | Just (Rational i, Rational j) <- prjPair pair = f i j & specialize - | Just (Rational i, Float j) <- prjPair pair = f i j & specialize - | Just (Float i, Integer j) <- prjPair pair = f i j & specialize - | Just (Float i, Rational j) <- prjPair pair = f i j & specialize - | Just (Float i, Float j) <- prjPair pair = f i j & specialize + | Just (Integer i, Integer j) <- prjPair pair = tentative f i j & specialize + | Just (Integer i, Rational j) <- prjPair pair = tentative f i j & specialize + | Just (Integer i, Float j) <- prjPair pair = tentative f i j & specialize + | Just (Rational i, Integer j) <- prjPair pair = tentative f i j & specialize + | Just (Rational i, Rational j) <- prjPair pair = tentative f i j & specialize + | Just (Rational i, Float j) <- prjPair pair = tentative f i j & specialize + | Just (Float i, Integer j) <- prjPair pair = tentative f i j & specialize + | Just (Float i, Rational j) <- prjPair pair = tentative f i j & specialize + | Just (Float i, Float j) <- prjPair pair = tentative f i j & specialize | otherwise = throwValueError (Numeric2Error left right) where + tentative x i j = attemptUnsafeArithmetic (x i j) + -- Dispatch whatever's contained inside a 'Number.SomeNumber' to its appropriate 'MonadValue' ctor - specialize :: MonadValue location value effects m => Number.SomeNumber -> m effects value - specialize (Number.SomeNumber (Number.Integer i)) = integer i - specialize (Number.SomeNumber (Number.Ratio r)) = rational r - specialize (Number.SomeNumber (Number.Decimal d)) = float d + specialize :: MonadEvaluatable location term value effects m => Either ArithException Number.SomeNumber -> m effects value + specialize (Left exc) = throwValueError (ArithmeticError exc) + specialize (Right (Number.SomeNumber (Number.Integer i))) = integer i + specialize (Right (Number.SomeNumber (Number.Ratio r))) = rational r + specialize (Right (Number.SomeNumber (Number.Decimal d))) = float d pair = (left, right) liftComparison comparator left right diff --git a/src/Data/Scientific/Exts.hs b/src/Data/Scientific/Exts.hs index 50df2a55d..d19d456fd 100644 --- a/src/Data/Scientific/Exts.hs +++ b/src/Data/Scientific/Exts.hs @@ -1,9 +1,11 @@ module Data.Scientific.Exts ( module Data.Scientific + , attemptUnsafeArithmetic , parseScientific ) where import Control.Applicative +import Control.Exception as Exc (evaluate, try) import Control.Monad hiding (fail) import Data.Attoparsec.ByteString.Char8 import Data.ByteString.Char8 hiding (readInt, takeWhile) @@ -13,6 +15,7 @@ import Numeric import Prelude hiding (fail, filter, null, takeWhile) import Prologue hiding (null) import Text.Read (readMaybe) +import System.IO.Unsafe parseScientific :: ByteString -> Either String Scientific parseScientific = parseOnly parser @@ -96,3 +99,10 @@ parser = signed (choice [hex, oct, bin, dec]) where let trail = if null trailings then "0" else trailings attempt (unpack (leads <> "." <> trail <> exponent)) + +-- | Attempt to evaluate the given term into WHNF. If doing so raises an 'ArithException', such as +-- 'ZeroDivisionError' or 'RatioZeroDenominator', 'Left' will be returned. +-- Hooray for uncatchable exceptions that bubble up from third-party code. +attemptUnsafeArithmetic :: a -> Either ArithException a +attemptUnsafeArithmetic = unsafePerformIO . Exc.try . evaluate +{-# NOINLINE attemptUnsafeArithmetic #-}