Initial implementation of type Z n (integers mod n).

This covers most of #510, including the type itself and the
class instances, but not any of the new primitive functions.
This commit is contained in:
Brian Huffman 2018-06-11 18:03:18 -07:00
parent e11806f64a
commit 7da2219caf
11 changed files with 183 additions and 34 deletions

View File

@ -329,6 +329,7 @@ etaDelay msg env0 Forall{ sVars = vs0, sType = tp0 } = goTpVars env0 vs0
case tp of
TVBit -> x
TVInteger -> x
TVIntMod _ -> x
TVSeq n TVBit ->
do w <- delayFill (fromWordVal "during eta-expansion" =<< x) (etaWord n x)

View File

@ -169,6 +169,7 @@ cpo that represents any given schema.
> case ty of
> TVBit -> VBit (fromVBit val)
> TVInteger -> VInteger (fromVInteger val)
> TVIntMod _ -> VInteger (fromVInteger val)
> TVSeq w ety -> VList (map (go ety) (copyList w (fromVList val)))
> TVStream ety -> VList (map (go ety) (copyStream (fromVList val)))
> TVTuple etys -> VTuple (zipWith go etys (copyList (genericLength etys) (fromVTuple val)))
@ -713,6 +714,12 @@ output bitvector will contain the exception in all bit positions.
> vWord w e = VList [ VBit (fmap (test i) e) | i <- [w-1, w-2 .. 0] ]
> where test i x = testBit x (fromInteger i)
Functions returning type `Z n` require normalizing the integer result
modulo `n`. If `n` is `0` or `inf`, then the result is unchanged.
> modulo :: Nat' -> Integer -> Integer
> modulo (Nat n) x = if n > 0 then x `mod` n else x
> modulo Inf x = x
Logic
-----
@ -729,6 +736,7 @@ at the same positions.
> where
> go TVBit = VBit b
> go TVInteger = VInteger (fmap (\c -> if c then -1 else 0) b)
> go (TVIntMod _) = VInteger (fmap (const 0) b)
> go (TVSeq n ety) = VList (genericReplicate n (go ety))
> go (TVStream ety) = VList (repeat (go ety))
> go (TVTuple tys) = VTuple (map go tys)
@ -743,6 +751,7 @@ at the same positions.
> case ty of
> TVBit -> VBit (fmap op (fromVBit val))
> TVInteger -> evalPanic "logicUnary" ["Integer not in class Logic"]
> TVIntMod _ -> evalPanic "logicUnary" ["Z not in class Logic"]
> TVSeq _w ety -> VList (map (go ety) (fromVList val))
> TVStream ety -> VList (map (go ety) (fromVList val))
> TVTuple etys -> VTuple (zipWith go etys (fromVTuple val))
@ -757,6 +766,7 @@ at the same positions.
> case ty of
> TVBit -> VBit (liftA2 op (fromVBit l) (fromVBit r))
> TVInteger -> evalPanic "logicBinary" ["Integer not in class Logic"]
> TVIntMod _ -> evalPanic "logicBinary" ["Z not in class Logic"]
> TVSeq _w ety -> VList (zipWith (go ety) (fromVList l) (fromVList r))
> TVStream ety -> VList (zipWith (go ety) (fromVList l) (fromVList r))
> TVTuple etys -> VTuple (zipWith3 go etys (fromVTuple l) (fromVTuple r))
@ -788,6 +798,8 @@ up of non-empty finite bitvectors.
> evalPanic "arithUnary" ["Bit not in class Arith"]
> TVInteger ->
> VInteger (op <$> fromVInteger val)
> TVIntMod n' ->
> VInteger (modulo n' <$> op <$> fromVInteger val)
> TVSeq w a
> | isTBit a -> vWord w (op <$> fromVWord val)
> | otherwise -> VList (map (go a) (fromVList val))
@ -826,6 +838,14 @@ up of non-empty finite bitvectors.
> case fromVInteger r of
> Left e -> Left e
> Right j -> op i j
> TVIntMod n' ->
> VInteger $
> case fromVInteger l of
> Left e -> Left e
> Right i ->
> case fromVInteger r of
> Left e -> Left e
> Right j -> modulo n' <$> op i j
> TVSeq w a
> | isTBit a -> vWord w $
> case fromWord l of
@ -887,6 +907,8 @@ bits to the *left* of that position are equal.
> compare <$> fromVBit l <*> fromVBit r
> TVInteger ->
> compare <$> fromVInteger l <*> fromVInteger r
> TVIntMod _ ->
> compare <$> fromVInteger l <*> fromVInteger r
> TVSeq _w ety ->
> lexList (zipWith (lexCompare ety) (fromVList l) (fromVList r))
> TVStream _ ->
@ -926,6 +948,8 @@ fields are compared in alphabetical order.
> evalPanic "lexSignedCompare" ["invalid type"]
> TVInteger ->
> evalPanic "lexSignedCompare" ["invalid type"]
> TVIntMod _ ->
> evalPanic "lexSignedCompare" ["invalid type"]
> TVSeq _w ety
> | isTBit ety ->
> case fromSignedVWord l of

View File

@ -29,6 +29,7 @@ import Control.DeepSeq
data TValue
= TVBit -- ^ @ Bit @
| TVInteger -- ^ @ Integer @
| TVIntMod Nat' -- ^ @ Z n @
| TVSeq Integer TValue -- ^ @ [n]a @
| TVStream TValue -- ^ @ [inf]t @
| TVTuple [TValue] -- ^ @ (a, b, c )@
@ -42,6 +43,7 @@ tValTy tv =
case tv of
TVBit -> tBit
TVInteger -> tInteger
TVIntMod n -> tIntMod (tNat' n)
TVSeq n t -> tSeq (tNum n) (tValTy t)
TVStream t -> tSeq tInf (tValTy t)
TVTuple ts -> tTuple (map tValTy ts)
@ -93,6 +95,7 @@ evalType env ty =
case (c, ts) of
(TCBit, []) -> Right $ TVBit
(TCInteger, []) -> Right $ TVInteger
(TCIntMod, [n]) -> Right $ TVIntMod (num n)
(TCSeq, [n, t]) -> Right $ tvSeq (num n) (val t)
(TCFun, [a, b]) -> Right $ TVFun (val a) (val b)
(TCTuple _, _) -> Right $ TVTuple (map val ts)

View File

@ -283,7 +283,7 @@ data GenValue b w i
= VRecord ![(Ident, Eval (GenValue b w i))] -- ^ @ { .. } @
| VTuple ![Eval (GenValue b w i)] -- ^ @ ( .. ) @
| VBit !b -- ^ @ Bit @
| VInteger !i -- ^ @ Integer @
| VInteger !i -- ^ @ Integer @ or @ Z n @
| VSeq !Integer !(SeqMap b w i) -- ^ @ [n]a @
-- Invariant: VSeq is never a sequence of bits
| VWord !Integer !(Eval (WordValue b w i)) -- ^ @ [n]Bit @
@ -781,6 +781,8 @@ toExpr prims t0 v0 = findOne (go t0 v0)
(TCon (TC TCBit) [], VBit False) -> return (prim "False")
(TCon (TC TCInteger) [], VInteger i) ->
return $ ETApp (prim "integer") (tNum i)
(TCon (TC TCIntMod) [_n], VInteger i) ->
return $ ETApp (prim "integer") (tNum i) --FIXME
(TCon (TC TCSeq) [a,b], VSeq 0 _) -> do
guard (a == tZero)
return $ EList [] b

View File

@ -344,6 +344,7 @@ tconNames :: Map.Map PName TC
tconNames = Map.fromList
[ (mkUnqual (packIdent "Bit"), TCBit)
, (mkUnqual (packIdent "Integer"), TCInteger)
, (mkUnqual (packIdent "Z"), TCIntMod)
, (mkUnqual (packIdent "inf"), TCInf)
]

View File

@ -61,21 +61,26 @@ instance EvalPrims Bool BV Integer where
primTable :: Map.Map Ident Value
primTable = Map.fromList $ map (\(n, v) -> (mkIdent (T.pack n), v))
[ ("+" , {-# SCC "Prelude::(+)" #-}
binary (arithBinary (liftBinArith (+)) (liftBinInteger (+))))
binary (arithBinary (liftBinArith (+)) (liftBinInteger (+))
(liftBinIntMod (+))))
, ("-" , {-# SCC "Prelude::(-)" #-}
binary (arithBinary (liftBinArith (-)) (liftBinInteger (-))))
binary (arithBinary (liftBinArith (-)) (liftBinInteger (-))
(liftBinIntMod (-))))
, ("*" , {-# SCC "Prelude::(*)" #-}
binary (arithBinary (liftBinArith (*)) (liftBinInteger (*))))
binary (arithBinary (liftBinArith (*)) (liftBinInteger (*))
(liftBinIntMod (*))))
, ("/" , {-# SCC "Prelude::(/)" #-}
binary (arithBinary (liftDivArith div) (liftDivInteger div)))
binary (arithBinary (liftDivArith div) (liftDivInteger div)
(const (liftDivInteger div))))
, ("%" , {-# SCC "Prelude::(%)" #-}
binary (arithBinary (liftDivArith mod) (liftDivInteger mod)))
binary (arithBinary (liftDivArith mod) (liftDivInteger mod)
(const (liftDivInteger mod))))
, ("^^" , {-# SCC "Prelude::(^^)" #-}
binary (arithBinary modExp integerExp))
binary (arithBinary modExp integerExp intModExp))
, ("lg2" , {-# SCC "Prelude::lg2" #-}
unary (arithUnary (liftUnaryArith lg2) lg2))
unary (arithUnary (liftUnaryArith lg2) lg2 (const . lg2)))
, ("negate" , {-# SCC "Prelude::negate" #-}
unary (arithUnary (liftUnaryArith negate) negate))
unary (arithUnary (liftUnaryArith negate) negate (const . negate)))
, ("<" , {-# SCC "Prelude::(<)" #-}
binary (cmpOrder "<" (\o -> o == LT )))
, (">" , {-# SCC "Prelude::(>)" #-}
@ -91,9 +96,11 @@ primTable = Map.fromList $ map (\(n, v) -> (mkIdent (T.pack n), v))
, ("<$" , {-# SCC "Prelude::(<$)" #-}
binary (signedCmpOrder "<$" (\o -> o == LT)))
, ("/$" , {-# SCC "Prelude::(/$)" #-}
binary (arithBinary (liftSigned bvSdiv) (liftDivInteger div)))
binary (arithBinary (liftSigned bvSdiv) (liftDivInteger div)
(const (liftDivInteger div))))
, ("%$" , {-# SCC "Prelude::(%$)" #-}
binary (arithBinary (liftSigned bvSrem) (liftDivInteger mod)))
binary (arithBinary (liftSigned bvSrem) (liftDivInteger mod)
(const (liftDivInteger mod))))
, (">>$" , {-# SCC "Prelude::(>>$)" #-}
sshrV)
, ("&&" , {-# SCC "Prelude::(&&)" #-}
@ -268,6 +275,12 @@ modExp bits (BV _ base) (BV _ e)
where
modulus = 0 `setBit` fromInteger bits
intModExp :: Integer -> Integer -> Integer -> Eval Integer
intModExp modulus base e
| modulus > 0 = ready $ doubleAndAdd base e modulus
| modulus == 0 = integerExp base e
| otherwise = evalPanic "intModExp" [ "negative modulus: " ++ show modulus ]
integerExp :: Integer -> Integer -> Eval Integer
integerExp x y
| y < 0 = negativeExponent
@ -333,6 +346,12 @@ type BinArith w = Integer -> w -> w -> Eval w
liftBinInteger :: (Integer -> Integer -> Integer) -> Integer -> Integer -> Eval Integer
liftBinInteger op x y = ready $ op x y
liftBinIntMod ::
(Integer -> Integer -> Integer) -> Integer -> Integer -> Integer -> Eval Integer
liftBinIntMod op m x y
| m == 0 = ready $ op x y
| otherwise = ready $ (op x y) `mod` m
liftDivInteger :: (Integer -> Integer -> Integer) -> Integer -> Integer -> Eval Integer
liftDivInteger _ _ 0 = divideByZero
liftDivInteger op x y = ready $ op x y
@ -345,8 +364,9 @@ arithBinary :: forall b w i
. BitWord b w i
=> BinArith w
-> (i -> i -> Eval i)
-> (Integer -> i -> i -> Eval i)
-> Binary b w i
arithBinary opw opi = loop
arithBinary opw opi opz = loop
where
loop' :: TValue
-> Eval (GenValue b w i)
@ -365,6 +385,13 @@ arithBinary opw opi = loop
TVInteger ->
VInteger <$> opi (fromVInteger l) (fromVInteger r)
TVIntMod n' ->
case n' of
Nat n ->
VInteger <$> opz n (fromVInteger l) (fromVInteger r)
Inf ->
VInteger <$> opi (fromVInteger l) (fromVInteger r)
TVSeq w a
-- words and finite sequences
| isTBit a -> do
@ -408,8 +435,9 @@ arithUnary :: forall b w i
. BitWord b w i
=> UnaryArith w
-> (i -> i)
-> (Integer -> i -> i)
-> Unary b w i
arithUnary opw opi = loop
arithUnary opw opi opz = loop
where
loop' :: TValue -> Eval (GenValue b w i) -> Eval (GenValue b w i)
loop' ty x = loop ty =<< x
@ -423,6 +451,13 @@ arithUnary opw opi = loop
TVInteger ->
return $ VInteger $ opi (fromVInteger x)
TVIntMod n' ->
case n' of
Nat n ->
return $ VInteger $ opz n (fromVInteger x)
Inf ->
return $ VInteger $ opi (fromVInteger x)
TVSeq w a
-- words and finite sequences
| isTBit a -> do
@ -465,13 +500,17 @@ cmpValue :: BitWord b w i
=> (b -> b -> Eval a -> Eval a)
-> (w -> w -> Eval a -> Eval a)
-> (i -> i -> Eval a -> Eval a)
-> (Integer -> i -> i -> Eval a -> Eval a)
-> (TValue -> GenValue b w i -> GenValue b w i -> Eval a -> Eval a)
cmpValue fb fw fi = cmp
cmpValue fb fw fi fz = cmp
where
cmp ty v1 v2 k =
case ty of
TVBit -> fb (fromVBit v1) (fromVBit v2) k
TVInteger -> fi (fromVInteger v1) (fromVInteger v2) k
TVIntMod n' -> case n' of
Nat n -> fz n (fromVInteger v1) (fromVInteger v2) k
Inf -> fi (fromVInteger v1) (fromVInteger v2) k
TVSeq n t
| isTBit t -> do w1 <- fromVWord "cmpValue" v1
w2 <- fromVWord "cmpValue" v2
@ -498,7 +537,7 @@ cmpValue fb fw fi = cmp
lexCompare :: TValue -> Value -> Value -> Eval Ordering
lexCompare ty a b = cmpValue op opw op ty a b (return EQ)
lexCompare ty a b = cmpValue op opw op (const op) ty a b (return EQ)
where
opw :: BV -> BV -> Eval Ordering -> Eval Ordering
opw x y k = op (bvVal x) (bvVal y) k
@ -509,7 +548,7 @@ lexCompare ty a b = cmpValue op opw op ty a b (return EQ)
cmp -> return cmp
signedLexCompare :: TValue -> Value -> Value -> Eval Ordering
signedLexCompare ty a b = cmpValue opb opw opi ty a b (return EQ)
signedLexCompare ty a b = cmpValue opb opw opi (const opi) ty a b (return EQ)
where
opb :: Bool -> Bool -> Eval Ordering -> Eval Ordering
opb _x _y _k = panic "signedLexCompare"
@ -627,6 +666,10 @@ zeroV ty = case ty of
TVInteger ->
VInteger (integerLit 0)
-- integers mod n
TVIntMod _ ->
VInteger (integerLit 0)
-- sequences
TVSeq w ety
| isTBit ety -> word w 0
@ -987,6 +1030,7 @@ logicBinary opb opw = loop
loop ty l r = case ty of
TVBit -> return $ VBit (opb (fromVBit l) (fromVBit r))
TVInteger -> evalPanic "logicBinary" ["Integer not in class Logic"]
TVIntMod _ -> evalPanic "logicBinary" ["Z not in class Logic"]
TVSeq w aty
-- words
| isTBit aty
@ -1050,6 +1094,7 @@ logicUnary opb opw = loop
TVBit -> return . VBit . opb $ fromVBit val
TVInteger -> evalPanic "logicUnary" ["Integer not in class Logic"]
TVIntMod _ -> evalPanic "logicUnary" ["Z not in class Logic"]
TVSeq w ety
-- words
@ -1408,6 +1453,7 @@ errorV ty msg = case ty of
-- bits
TVBit -> cryUserError msg
TVInteger -> cryUserError msg
TVIntMod _ -> cryUserError msg
-- sequences
TVSeq w ety

View File

@ -40,6 +40,7 @@ import qualified Cryptol.Eval.Type as Eval
import qualified Cryptol.Eval.Value as Eval
import Cryptol.Eval.Env (GenEvalEnv(..))
import Cryptol.TypeCheck.AST
import Cryptol.TypeCheck.Solver.InfNat(Nat'(..))
import Cryptol.Utils.Ident (Ident)
import Cryptol.Utils.PP
import Cryptol.Utils.Panic(panic)
@ -271,6 +272,7 @@ parseValue FTInteger cws =
case SBV.genParse SBV.KUnbounded cws of
Just (x, cws') -> (Eval.VInteger x, cws')
Nothing -> panic "Cryptol.Symbolic.parseValue" [ "no integer" ]
parseValue (FTIntMod _) cws = parseValue FTInteger cws
parseValue (FTSeq 0 FTBit) cws = (Eval.word 0 0, cws)
parseValue (FTSeq n FTBit) cws =
case SBV.genParse (SBV.KBounded False n) cws of
@ -295,6 +297,7 @@ allDeclGroups = concatMap mDecls . M.loadedModules
data FinType
= FTBit
| FTInteger
| FTIntMod Integer
| FTSeq Int FinType
| FTTuple [FinType]
| FTRecord [(Ident, FinType)]
@ -309,6 +312,9 @@ finType ty =
case ty of
Eval.TVBit -> Just FTBit
Eval.TVInteger -> Just FTInteger
Eval.TVIntMod n' -> case n' of
Nat n -> Just (FTIntMod n)
Inf -> Just FTInteger
Eval.TVSeq n t -> FTSeq <$> numType n <*> finType t
Eval.TVTuple ts -> FTTuple <$> traverse finType ts
Eval.TVRec fields -> FTRecord <$> traverse (traverseSnd finType) fields
@ -319,6 +325,7 @@ unFinType fty =
case fty of
FTBit -> tBit
FTInteger -> tInteger
FTIntMod n -> tIntMod (tNum n)
FTSeq l ety -> tSeq (tNum l) (unFinType ety)
FTTuple ftys -> tTuple (unFinType <$> ftys)
FTRecord fs -> tRec (zip fns tys)
@ -344,6 +351,7 @@ forallFinType ty =
case ty of
FTBit -> VBit <$> forallSBool_
FTInteger -> VInteger <$> forallSInteger_
FTIntMod _ -> VInteger <$> forallSInteger_
FTSeq 0 FTBit -> return $ Eval.word 0 0
FTSeq n FTBit -> VWord (toInteger n) . return . Eval.WordVal <$> (forallBV_ n)
FTSeq n t -> do vs <- replicateM n (forallFinType t)
@ -356,6 +364,7 @@ existsFinType ty =
case ty of
FTBit -> VBit <$> existsSBool_
FTInteger -> VInteger <$> existsSInteger_
FTIntMod _ -> VInteger <$> existsSInteger_
FTSeq 0 FTBit -> return $ Eval.word 0 0
FTSeq n FTBit -> VWord (toInteger n) . return . Eval.WordVal <$> (existsBV_ n)
FTSeq n t -> do vs <- replicateM n (existsFinType t)

View File

@ -81,25 +81,34 @@ primTable = Map.fromList $ map (\(n, v) -> (mkIdent (T.pack n), v))
-- { val, bits } (fin val, fin bits, bits >= width val) => [bits]
, ("integer" , ecIntegerV) -- Converts a numeric type into its corresponding value.
-- { val } (fin val) => Integer
, ("+" , binary (arithBinary (liftBinArith SBV.svPlus) (liftBin SBV.svPlus))) -- {a} (Arith a) => a -> a -> a
, ("-" , binary (arithBinary (liftBinArith SBV.svMinus) (liftBin SBV.svMinus))) -- {a} (Arith a) => a -> a -> a
, ("*" , binary (arithBinary (liftBinArith SBV.svTimes) (liftBin SBV.svTimes))) -- {a} (Arith a) => a -> a -> a
, ("/" , binary (arithBinary (liftBinArith SBV.svQuot) (liftBin SBV.svQuot))) -- {a} (Arith a) => a -> a -> a
, ("%" , binary (arithBinary (liftBinArith SBV.svRem) (liftBin SBV.svRem))) -- {a} (Arith a) => a -> a -> a
, ("^^" , binary (arithBinary sExp (liftBin SBV.svExp))) -- {a} (Arith a) => a -> a -> a
, ("lg2" , unary (arithUnary sLg2 svLg2)) -- {a} (Arith a) => a -> a
, ("negate" , unary (arithUnary (\_ -> ready . SBV.svUNeg) SBV.svUNeg))
, ("<" , binary (cmpBinary cmpLt cmpLt cmpLt SBV.svFalse))
, (">" , binary (cmpBinary cmpGt cmpGt cmpGt SBV.svFalse))
, ("<=" , binary (cmpBinary cmpLtEq cmpLtEq cmpLtEq SBV.svTrue))
, (">=" , binary (cmpBinary cmpGtEq cmpGtEq cmpGtEq SBV.svTrue))
, ("==" , binary (cmpBinary cmpEq cmpEq cmpEq SBV.svTrue))
, ("!=" , binary (cmpBinary cmpNotEq cmpNotEq cmpNotEq SBV.svFalse))
, ("+" , binary (arithBinary (liftBinArith SBV.svPlus) (liftBin SBV.svPlus)
(const (liftBin SBV.svPlus)))) -- {a} (Arith a) => a -> a -> a
, ("-" , binary (arithBinary (liftBinArith SBV.svMinus) (liftBin SBV.svMinus)
(const (liftBin SBV.svMinus)))) -- {a} (Arith a) => a -> a -> a
, ("*" , binary (arithBinary (liftBinArith SBV.svTimes) (liftBin SBV.svTimes)
(const (liftBin SBV.svTimes)))) -- {a} (Arith a) => a -> a -> a
, ("/" , binary (arithBinary (liftBinArith SBV.svQuot) (liftBin SBV.svQuot)
(liftModBin SBV.svQuot))) -- {a} (Arith a) => a -> a -> a
, ("%" , binary (arithBinary (liftBinArith SBV.svRem) (liftBin SBV.svRem)
(liftModBin SBV.svRem))) -- {a} (Arith a) => a -> a -> a
, ("^^" , binary (arithBinary sExp (liftBin SBV.svExp)
(liftModBin SBV.svRem))) -- {a} (Arith a) => a -> a -> a
, ("lg2" , unary (arithUnary sLg2 svLg2 svModLg2)) -- {a} (Arith a) => a -> a
, ("negate" , unary (arithUnary (\_ -> ready . SBV.svUNeg) SBV.svUNeg
(const SBV.svUNeg)))
, ("<" , binary (cmpBinary cmpLt cmpLt cmpLt (cmpMod cmpLt) SBV.svFalse))
, (">" , binary (cmpBinary cmpGt cmpGt cmpGt (cmpMod cmpGt) SBV.svFalse))
, ("<=" , binary (cmpBinary cmpLtEq cmpLtEq cmpLtEq (cmpMod cmpLtEq) SBV.svTrue))
, (">=" , binary (cmpBinary cmpGtEq cmpGtEq cmpGtEq (cmpMod cmpGtEq) SBV.svTrue))
, ("==" , binary (cmpBinary cmpEq cmpEq cmpEq cmpModEq SBV.svTrue))
, ("!=" , binary (cmpBinary cmpNotEq cmpNotEq cmpNotEq cmpModNotEq SBV.svFalse))
, ("<$" , let boolFail = evalPanic "<$" ["Attempted signed comparison on bare Bit values"]
intFail = evalPanic "<$" ["Attempted signed comparison on Integer values"]
in binary (cmpBinary boolFail cmpSignedLt intFail SBV.svFalse))
, ("/$" , binary (arithBinary (liftBinArith signedQuot) (liftBin SBV.svQuot)))
, ("%$" , binary (arithBinary (liftBinArith signedRem) (liftBin SBV.svRem)))
in binary (cmpBinary boolFail cmpSignedLt intFail (const intFail) SBV.svFalse))
, ("/$" , binary (arithBinary (liftBinArith signedQuot) (liftBin SBV.svQuot)
(liftModBin SBV.svQuot))) -- {a} (Arith a) => a -> a -> a
, ("%$" , binary (arithBinary (liftBinArith signedRem) (liftBin SBV.svRem)
(liftModBin SBV.svRem)))
, (">>$" , sshrV)
, ("&&" , binary (logicBinary SBV.svAnd SBV.svAnd))
, ("||" , binary (logicBinary SBV.svOr SBV.svOr))
@ -448,6 +457,10 @@ liftBinArith op _ x y = ready $ op x y
liftBin :: (a -> b -> c) -> a -> b -> Eval c
liftBin op x y = ready $ op x y
liftModBin :: (SInteger -> SInteger -> a) -> Integer -> SInteger -> SInteger -> Eval a
liftModBin op modulus x y = ready $ op (SBV.svRem x m) (SBV.svRem y m)
where m = integerLit modulus
sExp :: Integer -> SWord -> SWord -> Eval SWord
sExp _w x y = ready $ go (reverse (unpackWord y)) -- bits in little-endian order
where go [] = literalSWord (SBV.intSizeOf x) 1
@ -470,6 +483,10 @@ svLg2 x =
Just n -> SBV.svInteger SBV.KUnbounded (lg2 n)
Nothing -> evalPanic "cannot compute lg2 of symbolic unbounded integer" []
svModLg2 :: Integer -> SInteger -> SInteger
svModLg2 modulus x = svLg2 (SBV.svRem x m)
where m = integerLit modulus
-- Cmp -------------------------------------------------------------------------
cmpEq :: SWord -> SWord -> Eval SBool -> Eval SBool
@ -491,11 +508,27 @@ cmpLtEq, cmpGtEq :: SWord -> SWord -> Eval SBool -> Eval SBool
cmpLtEq x y k = SBV.svAnd (SBV.svLessEq x y) <$> (cmpNotEq x y k)
cmpGtEq x y k = SBV.svAnd (SBV.svGreaterEq x y) <$> (cmpNotEq x y k)
cmpMod ::
(SInteger -> SInteger -> Eval SBool -> Eval SBool) ->
(Integer -> SInteger -> SInteger -> Eval SBool -> Eval SBool)
cmpMod cmp modulus x y k = cmp (SBV.svRem x m) (SBV.svRem y m) k
where m = integerLit modulus
cmpModEq :: Integer -> SInteger -> SInteger -> Eval SBool -> Eval SBool
cmpModEq m x y k = SBV.svAnd (svDivisible m (SBV.svMinus x y)) <$> k
cmpModNotEq :: Integer -> SInteger -> SInteger -> Eval SBool -> Eval SBool
cmpModNotEq m x y k = SBV.svOr (SBV.svNot (svDivisible m (SBV.svMinus x y))) <$> k
svDivisible :: Integer -> SInteger -> SBool
svDivisible m x = SBV.svEqual (SBV.svRem x (integerLit m)) (integerLit 0)
cmpBinary :: (SBool -> SBool -> Eval SBool -> Eval SBool)
-> (SWord -> SWord -> Eval SBool -> Eval SBool)
-> (SInteger -> SInteger -> Eval SBool -> Eval SBool)
-> (Integer -> SInteger -> SInteger -> Eval SBool -> Eval SBool)
-> SBool -> Binary SBool SWord SInteger
cmpBinary fb fw fi b ty v1 v2 = VBit <$> cmpValue fb fw fi ty v1 v2 (return b)
cmpBinary fb fw fi fz b ty v1 v2 = VBit <$> cmpValue fb fw fi fz ty v1 v2 (return b)
-- Signed arithmetic -----------------------------------------------------------

View File

@ -89,6 +89,10 @@ typeSize ty =
(TCInf, _) -> Nothing
(TCBit, _) -> Just 2
(TCInteger, _) -> Nothing
(TCIntMod, [sz]) -> case tNoUser sz of
TCon (TC (TCNum n)) _ -> Just n
_ -> Nothing
(TCIntMod, _) -> Nothing
(TCSeq, [sz,el]) -> case tNoUser sz of
TCon (TC (TCNum n)) _ -> (^ n) <$> typeSize el
_ -> Nothing
@ -117,6 +121,11 @@ typeValues ty =
TCInf -> []
TCBit -> [ VBit False, VBit True ]
TCInteger -> []
TCIntMod ->
case map tNoUser ts of
[ TCon (TC (TCNum n)) _ ] | 0 < n ->
[ VInteger x | x <- [ 0 .. n - 1 ] ]
_ -> []
TCSeq ->
case map tNoUser ts of
[ TCon (TC (TCNum n)) _, TCon (TC TCBit) [] ] ->

View File

@ -45,6 +45,9 @@ solveZeroInst ty = case tNoUser ty of
-- Zero Integer
TCon (TC TCInteger) [] -> SolvedIf []
-- Zero (Z n)
TCon (TC TCIntMod) [_] -> SolvedIf []
-- Zero a => Zero [n]a
TCon (TC TCSeq) [_, a] -> SolvedIf [ pZero a ]
@ -106,6 +109,9 @@ solveArithInst ty = case tNoUser ty of
-- Arith Integer
TCon (TC TCInteger) [] -> SolvedIf []
-- Arith (Z n)
TCon (TC TCIntMod) [_] -> SolvedIf []
-- (Arith a, Arith b) => Arith { x1 : a, x2 : b }
TRec fs -> SolvedIf [ pArith ety | (_,ety) <- fs ]
@ -139,6 +145,9 @@ solveCmpInst ty = case tNoUser ty of
-- Cmp Integer
TCon (TC TCInteger) [] -> SolvedIf []
-- Cmp (Z n)
TCon (TC TCIntMod) [_] -> SolvedIf []
-- (fin n, Cmp a) => Cmp [n]a
TCon (TC TCSeq) [n,a] -> SolvedIf [ pFin n, pCmp a ]

View File

@ -153,6 +153,7 @@ data TC = TCNum Integer -- ^ Numbers
| TCInf -- ^ Inf
| TCBit -- ^ Bit
| TCInteger -- ^ Integer
| TCIntMod -- ^ @Z _@
| TCSeq -- ^ @[_] _@
| TCFun -- ^ @_ -> _@
| TCTuple Int -- ^ @(_, _, _)@
@ -222,6 +223,7 @@ instance HasKind TC where
TCInf -> KNum
TCBit -> KType
TCInteger -> KType
TCIntMod -> KNum :-> KType
TCSeq -> KNum :-> KType :-> KType
TCFun -> KType :-> KType :-> KType
TCTuple n -> foldr (:->) KType (replicate n KType)
@ -421,6 +423,11 @@ tIsInteger ty = case tNoUser ty of
TCon (TC TCInteger) [] -> True
_ -> False
tIsIntMod :: Type -> Maybe Type
tIsIntMod ty = case tNoUser ty of
TCon (TC TCIntMod) [n] -> Just n
_ -> Nothing
tIsTuple :: Type -> Maybe [Type]
tIsTuple ty = case tNoUser ty of
TCon (TC (TCTuple _)) ts -> Just ts
@ -523,6 +530,9 @@ tBit = TCon (TC TCBit) []
tInteger :: Type
tInteger = TCon (TC TCInteger) []
tIntMod :: Type -> Type
tIntMod n = TCon (TC TCIntMod) [n]
tWord :: Type -> Type
tWord a = tSeq a tBit
@ -811,6 +821,7 @@ instance PP (WithNames Type) where
(TCInf, []) -> text "inf"
(TCBit, []) -> text "Bit"
(TCInteger, []) -> text "Integer"
(TCIntMod, [n]) -> optParens (prec > 3) $ text "Z" <+> go 4 n
(TCSeq, [t1,TCon (TC TCBit) []]) -> brackets (go 0 t1)
(TCSeq, [t1,t2]) -> optParens (prec > 3)
@ -932,6 +943,7 @@ instance PP TC where
TCInf -> text "inf"
TCBit -> text "Bit"
TCInteger -> text "Integer"
TCIntMod -> text "Z"
TCSeq -> text "[]"
TCFun -> text "(->)"
TCTuple 0 -> text "()"