Factor out compatibility shims into GHC.Num.Compat

This commit is contained in:
Ryan Scott 2022-01-13 12:13:40 -05:00
parent 7bdb84beca
commit 1181076f1f
7 changed files with 200 additions and 52 deletions

View File

@ -41,6 +41,7 @@ library
Default-language:
Haskell2010
Build-depends: base >= 4.8 && < 5,
arithmoi >= 0.12,
async >= 2.2 && < 2.3,
base-compat >= 0.6 && < 0.12,
bv-sized >= 1.0 && < 1.1,
@ -56,8 +57,6 @@ library
ghc-prim,
GraphSCC >= 1.0.4,
heredoc >= 0.2,
ghc-bignum,
arithmoi,
libBF >= 0.6 && < 0.7,
MemoTrie >= 0.6 && < 0.7,
monad-control >= 1.0,
@ -77,6 +76,11 @@ library
panic >= 0.3,
what4 >= 1.2 && < 1.3
if impl(ghc >= 9.0)
build-depends: ghc-bignum >= 1.0 && < 1.3
else
build-depends: integer-gmp >= 1.0 && < 1.1
Build-tool-depends: alex:alex, happy:happy
hs-source-dirs: src
@ -202,6 +206,7 @@ library
Other-modules: Cryptol.Parser.LexerUtils,
Cryptol.Parser.ParserUtils,
Cryptol.Prelude,
GHC.Num.Compat,
Paths_cryptol,
GitRev

View File

@ -42,7 +42,7 @@ import Data.Bits
import Data.Ratio
import Numeric (showIntAtBase)
import qualified LibBF as FP
import qualified GHC.Num.Integer as Integer
import qualified GHC.Num.Compat as Integer
import qualified Cryptol.Backend.Arch as Arch
import qualified Cryptol.Backend.FloatHelpers as FP
@ -343,8 +343,8 @@ instance Backend Concrete where
-- the only values for which no inverse exists are
-- congruent to 0 modulo m.
znRecip sym m x =
case Integer.integerRecipMod# x (Integer.integerToNaturalClamp m) of
(# r | #) -> integerLit sym (toInteger r)
case Integer.integerRecipMod x m of
(# r | #) -> integerLit sym r
(# | () #) -> raiseError sym DivideByZero
znPlus _ = liftBinIntMod (+)

View File

@ -40,7 +40,7 @@ import Control.Monad.IO.Class (MonadIO(..))
import Data.Bits (bit, complement)
import Data.List (foldl')
import qualified GHC.Num.Integer as Integer
import qualified GHC.Num.Compat as Integer
import Data.SBV.Dynamic as SBV
import qualified Data.SBV.Internals as SBV
@ -431,8 +431,8 @@ sModRecip _sym 0 _ = panic "sModRecip" ["0 modulus not allowed"]
sModRecip sym m x
-- If the input is concrete, evaluate the answer
| Just xi <- svAsInteger x
= case Integer.integerRecipMod# xi (Integer.integerToNaturalClamp m) of
(# r | #) -> integerLit sym (toInteger r)
= case Integer.integerRecipMod xi m of
(# r | #) -> integerLit sym r
(# | () #) -> raiseError sym DivideByZero
-- If the input is symbolic, create a new symbolic constant

View File

@ -30,7 +30,7 @@ import Data.Text (Text)
import Data.Parameterized.NatRepr
import Data.Parameterized.Some
import qualified GHC.Num.Integer as Integer
import qualified GHC.Num.Compat as Integer
import qualified What4.Interface as W4
import qualified What4.SWord as SW
@ -343,7 +343,7 @@ instance W4.IsSymExprBuilder sym => Backend (What4 sym) where
wordMult sym x y = liftIO (SW.bvMul (w4 sym) x y)
wordNegate sym x = liftIO (SW.bvNeg (w4 sym) x)
wordLg2 sym x = sLg2 (w4 sym) x
wordShiftLeft sym x y = w4bvShl (w4 sym) x y
wordShiftRight sym x y = w4bvLshr (w4 sym) x y
wordRotateLeft sym x y = w4bvRol (w4 sym) x y
@ -670,8 +670,8 @@ sModRecip _sym 0 _ = panic "sModRecip" ["0 modulus not allowed"]
sModRecip sym m x
-- If the input is concrete, evaluate the answer
| Just xi <- W4.asInteger x
= case Integer.integerRecipMod# xi (Integer.integerToNaturalClamp m) of
(# r | #) -> integerLit sym (toInteger r)
= case Integer.integerRecipMod xi m of
(# r | #) -> integerLit sym r
(# | () #) -> raiseError sym DivideByZero
-- If the input is symbolic, create a new symbolic constant

View File

@ -33,7 +33,7 @@
> import qualified Data.Text as T (pack)
> import LibBF (BigFloat)
> import qualified LibBF as FP
> import qualified GHC.Num.Integer as Integer
> import qualified GHC.Num.Compat as Integer
>
> import Cryptol.ModuleSystem.Name (asPrim)
> import Cryptol.TypeCheck.Solver.InfNat (Nat'(..), nAdd, nMin, nMul)
@ -1334,8 +1334,8 @@ confused with integral division).
>
> zRecip :: Integer -> Integer -> E Integer
> zRecip m x =
> case Integer.integerRecipMod# x (Integer.integerToNaturalClamp m) of
> (# r | #) -> pure (toInteger r)
> case Integer.integerRecipMod x m of
> (# r | #) -> pure r
> (# | () #) -> cryError DivideByZero
>
> zDiv :: Integer -> Integer -> Integer -> E Integer

View File

@ -29,8 +29,8 @@ module Cryptol.PrimeEC
, primeModulus
, ProjectivePoint(..)
, toProjectivePoint
, integerToBigNat
, bigNatToInteger
, BN.integerToBigNat
, BN.bigNatToInteger
, ec_double
, ec_add_nonzero
@ -39,12 +39,15 @@ module Cryptol.PrimeEC
) where
{-
import GHC.Num.BigNat (BigNat#)
import qualified GHC.Num.Backend as BN
import qualified GHC.Num.BigNat as BN
import qualified GHC.Num.Integer as BN
import GHC.Prim
import GHC.Types
-}
import GHC.Num.Compat (BigNat#)
import qualified GHC.Num.Compat as BN
import GHC.Exts
import Cryptol.TypeCheck.Solver.InfNat (widthInteger)
import Cryptol.Utils.Panic
@ -61,20 +64,12 @@ data ProjectivePoint =
toProjectivePoint :: Integer -> Integer -> Integer -> ProjectivePoint
toProjectivePoint x y z =
ProjectivePoint (integerToBigNat x) (integerToBigNat y) (integerToBigNat z)
ProjectivePoint (BN.integerToBigNat x) (BN.integerToBigNat y) (BN.integerToBigNat z)
-- | The projective "point at infinity", which represents the zero element
-- of the ECC group.
zro :: ProjectivePoint
zro = ProjectivePoint (BN.bigNatFromWord# 1##) (BN.bigNatFromWord# 1##) (BN.bigNatFromWord# 0##)
-- | Coerce an integer value to a @BigNat@. This operation only really makes
-- sense for nonnegative values, but this condition is not checked.
integerToBigNat :: Integer -> BigNat#
integerToBigNat = BN.integerToBigNatClamp#
bigNatToInteger :: BigNat# -> Integer
bigNatToInteger = BN.integerFromBigNat#
zro = ProjectivePoint (BN.oneBigNat (# #)) (BN.oneBigNat (# #)) (BN.zeroBigNat (# #))
-- | Simple newtype wrapping the @BigNat@ value of the
-- modulus of the underlying field Z p. This modulus
@ -85,7 +80,7 @@ newtype PrimeModulus = PrimeModulus { primeMod :: BigNat# }
-- | Inject an integer value into the @PrimeModulus@ type.
-- This modulus is required to be prime.
primeModulus :: Integer -> PrimeModulus
primeModulus x = PrimeModulus (integerToBigNat x)
primeModulus x = PrimeModulus (BN.integerToBigNat x)
{-# INLINE primeModulus #-}
@ -104,10 +99,10 @@ mod_add p x y =
-- in @Z p@ when @p > 2@. The input @x@ is required to be in reduced form,
-- and will output a value in reduced form.
mod_half :: PrimeModulus -> BigNat# -> BigNat#
mod_half p x = if BN.bigNatTestBit x 0 then qodd else qeven
mod_half p x = if BN.testBitBigNat x 0# then qodd else qeven
where
qodd = (BN.bigNatAdd x (primeMod p)) `BN.bigNatShiftR#` 1##
qeven = x `BN.bigNatShiftR#` 1##
qodd = (BN.bigNatAdd x (primeMod p)) `BN.shiftRBigNat` 1#
qeven = x `BN.shiftRBigNat` 1#
-- | Compute the modular multiplication of two input values. Currently, this
-- uses naive modular reduction, and does not require the inputs to be in
@ -134,7 +129,7 @@ mod_square p x = BN.bigNatSqr x `BN.bigNatRem` primeMod p
-- will be in reduced form.
mul2 :: PrimeModulus -> BigNat# -> BigNat#
mul2 p x =
let r = x `BN.bigNatShiftL#` 1## in
let r = x `BN.shiftLBigNat` 1# in
case BN.bigNatSub r (primeMod p) of
(# (# #) | #) -> r
(# | rmp #) -> rmp
@ -206,7 +201,7 @@ ec_sub :: PrimeModulus -> ProjectivePoint -> ProjectivePoint -> ProjectivePoint
ec_sub p s t = ec_add p s u
where u = case BN.bigNatSub (primeMod p) (py t) of
(# | y' #) -> t{ py = y' }
(# (# #) | #) -> panic "ec_sub" ["cooridnate not in reduced form!", show (bigNatToInteger (py t))]
(# (# #) | #) -> panic "ec_sub" ["cooridnate not in reduced form!", show (BN.bigNatToInteger (py t))]
{-# INLINE ec_sub #-}
@ -275,11 +270,11 @@ ec_add_nonzero p s@(ProjectivePoint sx sy sz) (ProjectivePoint tx ty tz) =
ec_normalize :: PrimeModulus -> ProjectivePoint -> ProjectivePoint
ec_normalize p s@(ProjectivePoint x y z)
| BN.bigNatIsOne z = s
| otherwise = ProjectivePoint x' y' (BN.bigNatFromWord# 1##)
| otherwise = ProjectivePoint x' y' (BN.oneBigNat (# #))
where
m = primeMod p
l = BN.sbignat_recip_mod 0# z m
l = BN.recipModBigNat z m
l2 = BN.bigNatSqr l
l3 = BN.bigNatMul l l2
@ -297,15 +292,15 @@ ec_mult p d s
| BN.bigNatIsZero (pz s) = zro
| otherwise =
case m of
0# -> panic "ec_mult" ["modulus too large", show (bigNatToInteger (primeMod p))]
0# -> panic "ec_mult" ["modulus too large", show (BN.bigNatToInteger (primeMod p))]
_ -> go m zro
where
s' = ec_normalize p s
h = 3*d
d' = integerToBigNat d
h' = integerToBigNat h
d' = BN.integerToBigNat d
h' = BN.integerToBigNat h
m = case widthInteger h of
BN.IS mint -> mint
@ -317,9 +312,8 @@ ec_mult p d s
| otherwise = go (i -# 1#) r'
where
wi = int2Word# i
h_i = isTrue# (BN.bigNatTestBit# h' wi)
d_i = isTrue# (BN.bigNatTestBit# d' wi)
h_i = BN.testBitBigNat h' i
d_i = BN.testBitBigNat d' i
r' = if h_i then
if d_i then r2 else ec_add p r2 s'
@ -389,7 +383,7 @@ normalizeForTwinMult p s t
abcd = mod_mul p a bcd
e = BN.sbignat_recip_mod 0# abcd m
e = BN.recipModBigNat abcd m
a_inv = mod_mul p e bcd
b_inv = mod_mul p e acd
@ -408,11 +402,11 @@ normalizeForTwinMult p s t
d_inv2 = mod_square p d_inv
d_inv3 = mod_mul p d_inv d_inv2
s' = ProjectivePoint (mod_mul p (px s) a_inv2) (mod_mul p (py s) a_inv3) (BN.bigNatFromWord# 1##)
t' = ProjectivePoint (mod_mul p (px t) b_inv2) (mod_mul p (py t) b_inv3) (BN.bigNatFromWord# 1##)
s' = ProjectivePoint (mod_mul p (px s) a_inv2) (mod_mul p (py s) a_inv3) (BN.oneBigNat (# #))
t' = ProjectivePoint (mod_mul p (px t) b_inv2) (mod_mul p (py t) b_inv3) (BN.oneBigNat (# #))
spt' = ProjectivePoint (mod_mul p (px spt) c_inv2) (mod_mul p (py spt) c_inv3) (BN.bigNatFromWord# 1##)
smt' = ProjectivePoint (mod_mul p (px smt) d_inv2) (mod_mul p (py smt) d_inv3) (BN.bigNatFromWord# 1##)
spt' = ProjectivePoint (mod_mul p (px spt) c_inv2) (mod_mul p (py spt) c_inv3) (BN.oneBigNat (# #))
smt' = ProjectivePoint (mod_mul p (px smt) d_inv2) (mod_mul p (py smt) d_inv3) (BN.oneBigNat (# #))
-- | Given an integer @j@ and a projective point @S@, together with
@ -425,15 +419,15 @@ ec_twin_mult :: PrimeModulus ->
Integer -> ProjectivePoint ->
Integer -> ProjectivePoint ->
ProjectivePoint
ec_twin_mult p (integerToBigNat -> d0) s (integerToBigNat -> d1) t =
ec_twin_mult p (BN.integerToBigNat -> d0) s (BN.integerToBigNat -> d1) t =
case m of
0# -> panic "ec_twin_mult" ["modulus too large", show (bigNatToInteger (primeMod p))]
0# -> panic "ec_twin_mult" ["modulus too large", show (BN.bigNatToInteger (primeMod p))]
_ -> go m init_c0 init_c1 zro
where
(s',t',spt',smt') = normalizeForTwinMult p s t
m = case max 4 (widthInteger (bigNatToInteger (primeMod p))) of
m = case max 4 (widthInteger (BN.bigNatToInteger (primeMod p))) of
BN.IS mint -> mint
_ -> 0# -- if `m` doesn't fit into an Int, should be impossible
@ -441,7 +435,7 @@ ec_twin_mult p (integerToBigNat -> d0) s (integerToBigNat -> d1) t =
init_c1 = C False False (tst d1 (m -# 1#)) (tst d1 (m -# 2#)) (tst d1 (m -# 3#)) (tst d1 (m -# 4#))
tst x i
| isTrue# (i >=# 0#) = isTrue# (BN.bigNatTestBit# x (int2Word# i))
| isTrue# (i >=# 0#) = BN.testBitBigNat x i
| otherwise = False
f i =

149
src/GHC/Num/Compat.hs Normal file
View File

@ -0,0 +1,149 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE UnboxedSums #-}
{-# LANGUAGE UnboxedTuples #-}
-- |
-- Module : GHC.Num.Compat
-- Description : Defines numeric compatibility shims that work with both
-- ghc-bignum (GHC 9.0+) and integer-gmp (older GHCs).
-- Copyright : (c) 2021 Galois, Inc.
-- License : BSD3
-- Maintainer : cryptol@galois.com
-- Stability : provisional
-- Portability : portable
module GHC.Num.Compat
( -- * BigNat#
BigNat#
, bigNatAdd
, bigNatIsOne
, bigNatIsZero
, bigNatMul
, bigNatRem
, bigNatSqr
, bigNatSub
, bigNatSubUnsafe
, oneBigNat
, recipModBigNat
, shiftLBigNat
, shiftRBigNat
, testBitBigNat
, zeroBigNat
-- * Integer
, Integer(IS, IP, IN)
, integerRecipMod
-- * Conversions
, bigNatToInteger
, integerToBigNat
) where
#if defined(MIN_VERSION_ghc_bignum)
import GHC.Num.BigNat (BigNat#, bigNatAdd, bigNatIsOne, bigNatIsZero, bigNatMul, bigNatRem, bigNatSqr, bigNatSub, bigNatSubUnsafe)
import qualified GHC.Num.Backend as BN
import qualified GHC.Num.BigNat as BN
import GHC.Num.Integer (Integer(IS, IP, IN))
import qualified GHC.Num.Integer as Integer
import GHC.Exts
-- | Coerce a @BigNat#@ to an integer value.
bigNatToInteger :: BigNat# -> Integer
bigNatToInteger = Integer.integerFromBigNat#
integerRecipMod :: Integer -> Integer -> (# Integer | () #)
integerRecipMod x y =
case Integer.integerRecipMod# x (Integer.integerToNaturalClamp y) of
(# r | #) -> (# toInteger r | #)
(# | () #) -> (# | () #)
-- | Coerce an integer value to a @BigNat#@. This operation only really makes
-- sense for nonnegative values, but this condition is not checked.
integerToBigNat :: Integer -> BigNat#
integerToBigNat = Integer.integerToBigNatClamp#
-- Top-level unlifted bindings aren't allowed, so we fake one with a thunk.
oneBigNat :: (# #) -> BigNat#
oneBigNat _ = BN.bigNatFromWord# 1##
recipModBigNat :: BigNat# -> BigNat# -> BigNat#
recipModBigNat = BN.sbignat_recip_mod 0#
shiftLBigNat :: BigNat# -> Int# -> BigNat#
shiftLBigNat bn i = BN.bigNatShiftL# bn (int2Word# i)
shiftRBigNat :: BigNat# -> Int# -> BigNat#
shiftRBigNat bn i = BN.bigNatShiftR# bn (int2Word# i)
testBitBigNat :: BigNat# -> Int# -> Bool
testBitBigNat bn i = isTrue# (BN.bigNatTestBit# bn (int2Word# i))
-- Top-level unlifted bindings aren't allowed, so we fake one with a thunk.
zeroBigNat :: (# #) -> BigNat#
zeroBigNat _ = BN.bigNatFromWord# 0##
#else
import GHC.Integer.GMP.Internals (bigNatToInteger, recipModBigNat, shiftLBigNat, shiftRBigNat, testBitBigNat)
import qualified GHC.Integer.GMP.Internals as GMP
import GHC.Exts
type BigNat# = GMP.BigNat
{-# COMPLETE IS, IP, IN #-}
pattern IS :: Int# -> Integer
pattern IS i = GMP.S# i
pattern IP :: ByteArray# -> Integer
pattern IP ba = GMP.Jp# (GMP.BN# ba)
pattern IN :: ByteArray# -> Integer
pattern IN ba = GMP.Jn# (GMP.BN# ba)
bigNatAdd :: BigNat# -> BigNat# -> BigNat#
bigNatAdd = GMP.plusBigNat
bigNatIsOne :: BigNat# -> Bool
bigNatIsOne bn = GMP.eqBigNat bn GMP.oneBigNat
bigNatIsZero :: BigNat# -> Bool
bigNatIsZero = GMP.isZeroBigNat
bigNatMul :: BigNat# -> BigNat# -> BigNat#
bigNatMul = GMP.timesBigNat
bigNatRem :: BigNat# -> BigNat# -> BigNat#
bigNatRem = GMP.remBigNat
bigNatSqr :: BigNat# -> BigNat#
bigNatSqr = GMP.sqrBigNat
bigNatSub :: BigNat# -> BigNat# -> (# (# #) | BigNat# #)
bigNatSub x y =
case GMP.isNullBigNat# res of
0# -> (# | res #)
_ -> (# (# #) | #)
where
res = GMP.minusBigNat x y
bigNatSubUnsafe :: BigNat# -> BigNat# -> BigNat#
bigNatSubUnsafe = GMP.minusBigNat
integerToBigNat :: Integer -> BigNat#
integerToBigNat (GMP.S# i) = GMP.wordToBigNat (int2Word# i)
integerToBigNat (GMP.Jp# b) = b
integerToBigNat (GMP.Jn# b) = b
integerRecipMod :: Integer -> Integer -> (# Integer | () #)
integerRecipMod x y
| res == 0 = (# | () #)
| otherwise = (# res | #)
where
res = GMP.recipModInteger x y
oneBigNat :: (##) -> BigNat#
oneBigNat _ = GMP.oneBigNat
zeroBigNat :: (##) -> BigNat#
zeroBigNat _ = GMP.zeroBigNat
#endif