Merge branch 'master' of github.com:GaloisInc/cryptol

This commit is contained in:
Iavor S. Diatchki 2017-02-23 15:22:40 -08:00
commit fddcd60d10

View File

@ -12,7 +12,7 @@
module Cryptol.Eval.Reference where
import qualified Control.Exception as X (throw)
import Control.Applicative (liftA2)
import Control.Monad (foldM)
import Data.Bits
import Data.List
@ -26,8 +26,8 @@ import Cryptol.ModuleSystem.Name (asPrim)
import Cryptol.TypeCheck.Solver.InfNat (Nat'(..))
import Cryptol.TypeCheck.AST
import Cryptol.Eval.Monad (EvalError(..))
import Cryptol.Eval.Type (TypeEnv, TValue(..), isTBit, evalValType, evalNumType)
import Cryptol.Prims.Eval (divWrap, modWrap, lg2, divModPoly)
import Cryptol.Eval.Type (TypeEnv, TValue(..), isTBit, evalValType, evalNumType, tvSeq)
import Cryptol.Prims.Eval (lg2, divModPoly)
import Cryptol.Utils.Ident (Ident, mkIdent)
import Cryptol.Utils.Panic (panic)
import Cryptol.Utils.PP
@ -53,16 +53,16 @@ evaluate expr modEnv = return (Right (evalExpr env expr, modEnv), [])
-- of a @VBit@ constructor. All other @Value@ and list constructors
-- should evaluate without error.
data Value
= VRecord [(Ident, Value)] -- ^ @ { .. } @
| VTuple [Value] -- ^ @ ( .. ) @
| VBit Bool -- ^ @ Bit @
| VList [Value] -- ^ @ [n]a @ (either finite or infinite)
| VFun (Value -> Value) -- ^ functions
| VPoly (TValue -> Value) -- ^ polymorphic values (kind *)
| VNumPoly (Nat' -> Value) -- ^ polymorphic values (kind #)
= VRecord [(Ident, Value)] -- ^ @ { .. } @
| VTuple [Value] -- ^ @ ( .. ) @
| VBit (Either EvalError Bool) -- ^ @ Bit @
| VList [Value] -- ^ @ [n]a @ (either finite or infinite)
| VFun (Value -> Value) -- ^ functions
| VPoly (TValue -> Value) -- ^ polymorphic values (kind *)
| VNumPoly (Nat' -> Value) -- ^ polymorphic values (kind #)
-- | Destructor for @VBit@.
fromVBit :: Value -> Bool
fromVBit :: Value -> Either EvalError Bool
fromVBit (VBit b) = b
fromVBit _ = evalPanic "fromVBit" ["Expected a bit"]
@ -81,6 +81,16 @@ fromVFun :: Value -> (Value -> Value)
fromVFun (VFun f) = f
fromVFun _ = evalPanic "fromVFun" ["Expected a function"]
-- | Destructor for @VPoly@.
fromVPoly :: Value -> (TValue -> Value)
fromVPoly (VPoly f) = f
fromVPoly _ = evalPanic "fromVPoly" ["Expected a polymorphic value"]
-- | Destructor for @VNumPoly@.
fromVNumPoly :: Value -> (Nat' -> Value)
fromVNumPoly (VNumPoly f) = f
fromVNumPoly _ = evalPanic "fromVNumPoly" ["Expected a polymorphic value"]
-- | Destructor for @VRecord@.
fromVRecord :: Value -> [(Ident, Value)]
fromVRecord (VRecord fs) = fs
@ -105,12 +115,20 @@ integerToBits w x = go [] w x
go bs n a = go (odd a : bs) (n - 1) $! (a `div` 2)
-- | Convert a value from a big-endian binary format to an integer.
fromVWord :: Value -> Integer
fromVWord v = bitsToInteger (map fromVBit (fromVList v))
fromVWord :: Value -> Either EvalError Integer
fromVWord v = fmap bitsToInteger (mapM fromVBit (fromVList v))
-- | Convert an integer to big-endian binary value of the specified width.
vWord :: Integer -> Integer -> Value
vWord w x = VList (map VBit (integerToBits w x))
vWordValue :: Integer -> Integer -> Value
vWordValue w x = VList (map (VBit . Right) (integerToBits w x))
-- | Create a run-time error value of bitvector type.
vWordError :: Integer -> EvalError -> Value
vWordError w e = VList (genericReplicate w (VBit (Left e)))
vWord :: Integer -> Either EvalError Integer -> Value
vWord w (Left e) = vWordError w e
vWord w (Right x) = vWordValue w x
vFinPoly :: (Integer -> Value) -> Value
vFinPoly f = VNumPoly g
@ -118,6 +136,25 @@ vFinPoly f = VNumPoly g
g (Nat n) = f n
g Inf = evalPanic "vFinPoly" ["Expected finite numeric type"]
-- Conditionals ----------------------------------------------------------------
condBit :: Either e Bool -> Either e a -> Either e a -> Either e a
condBit (Left e) _ _ = Left e
condBit (Right b) x y = if b then x else y
condValue :: Either EvalError Bool -> Value -> Value -> Value
condValue c l r =
case l of
VRecord fs -> VRecord [ (f, condValue c v (lookupRecord f r)) | (f, v) <- fs ]
VTuple vs -> VTuple (zipWith (condValue c) vs (fromVList r))
VBit b -> VBit (condBit c b (fromVBit r))
VList vs -> VList (zipWith (condValue c) vs (fromVList r))
VFun f -> VFun (\v -> condValue c (f v) (fromVFun r v))
VPoly f -> VPoly (\t -> condValue c (f t) (fromVPoly r t))
VNumPoly f -> VNumPoly (\n -> condValue c (f n) (fromVNumPoly r n))
-- Environments ----------------------------------------------------------------
-- | Evaluation environment.
@ -158,11 +195,7 @@ evalExpr env expr =
ETuple es -> VTuple [ evalExpr env e | e <- es ]
ERec fields -> VRecord [ (f, evalExpr env e) | (f, e) <- fields ]
ESel e sel -> evalSel (evalExpr env e) sel
EIf c t f -> evalExpr env (if fromVBit (evalExpr env c) then t else f)
-- FIXME: this produces an invalid result if evaluation of the
-- condition yields a run-time error or fails to terminate. We
-- should use a zip-like function to push the conditionals down
-- into the leaf bits.
EIf c t f -> condValue (fromVBit (evalExpr env c)) (evalExpr env t) (evalExpr env f)
EComp _n _ty h gs ->
evalComp env h gs
@ -222,7 +255,7 @@ evalSel val sel =
case v of
VList vs -> vs !! n
_ -> evalPanic "evalSel"
[ "Unexpected value in list selection" ]
["Unexpected value in list selection."]
-- List Comprehensions ---------------------------------------------------------
@ -339,9 +372,9 @@ evalPrim n
primTable :: Map.Map Ident Value
primTable = Map.fromList $ map (\(n, v) -> (mkIdent (T.pack n), v))
[ ("+" , binary (arithBinary (const (+))))
, ("-" , binary (arithBinary (const (-))))
, ("*" , binary (arithBinary (const (*))))
[ ("+" , binary (arithBinary (\_ x y -> Right (x + y))))
, ("-" , binary (arithBinary (\_ x y -> Right (x - y))))
, ("*" , binary (arithBinary (\_ x y -> Right (x * y))))
, ("/" , binary (arithBinary (const divWrap)))
, ("%" , binary (arithBinary (const modWrap)))
-- , ("^^" , binary (arithBinary modExp))
@ -361,12 +394,12 @@ primTable = Map.fromList $ map (\(n, v) -> (mkIdent (T.pack n), v))
, (">>" , shiftV shiftRV)
, ("<<<" , rotateV rotateLV)
, (">>>" , rotateV rotateRV)
, ("True" , VBit True)
, ("False" , VBit False)
, ("True" , VBit (Right True))
, ("False" , VBit (Right False))
, ("demote" , vFinPoly $ \val ->
vFinPoly $ \bits ->
vWord bits val)
vWordValue bits val)
, ("#" , VNumPoly $ \_front ->
VNumPoly $ \_back ->
@ -382,7 +415,7 @@ primTable = Map.fromList $ map (\(n, v) -> (mkIdent (T.pack n), v))
, ("update" , updatePrim updateFront)
, ("updateEnd" , updatePrim updateBack)
, ("zero" , VPoly (logicNullary False))
, ("zero" , VPoly (logicNullary (Right False)))
, ("join" , VNumPoly $ \_parts ->
VNumPoly $ \_each ->
@ -407,32 +440,40 @@ primTable = Map.fromList $ map (\(n, v) -> (mkIdent (T.pack n), v))
vFinPoly $ \next ->
vFinPoly $ \bits ->
vFinPoly $ \len ->
VList (map (vWord bits) (genericTake len [first, next ..])))
VList (map (vWordValue bits) (genericTake len [first, next ..])))
, ("fromTo" , vFinPoly $ \first ->
vFinPoly $ \lst ->
vFinPoly $ \bits ->
VList (map (vWord bits) [first .. lst]))
VList (map (vWordValue bits) [first .. lst]))
, ("fromThenTo" , vFinPoly $ \first ->
vFinPoly $ \next ->
vFinPoly $ \_lst ->
vFinPoly $ \bits ->
vFinPoly $ \len ->
VList (map (vWord bits) (genericTake len [first, next ..])))
VList (map (vWordValue bits) (genericTake len [first, next ..])))
, ("infFrom" , vFinPoly $ \bits ->
VFun $ \first ->
VList (map (vWord bits) [fromVWord first ..]))
case fromVWord first of
Left e -> VList (repeat (vWordError bits e))
Right i -> VList (map (vWordValue bits) [i ..]))
, ("infFromThen", vFinPoly $ \bits ->
VFun $ \first ->
VFun $ \next ->
VList (map (vWord bits) [fromVWord first, fromVWord next ..]))
case fromVWord first of
Left e -> VList (repeat (vWordError bits e))
Right i ->
case fromVWord next of
Left e -> VList (repeat (vWordError bits e))
Right j -> VList (map (vWordValue bits) [i, j ..]))
, ("error" , VPoly $ \a ->
VNumPoly $ \_ ->
VFun $ \_s -> logicNullary (error "error") a)
VFun $ \_s -> logicNullary (Left (UserError "error")) a)
-- TODO: obtain error string from argument s
, ("reverse" , VNumPoly $ \_a ->
VPoly $ \_b ->
@ -450,18 +491,34 @@ primTable = Map.fromList $ map (\(n, v) -> (mkIdent (T.pack n), v))
in vFinPoly $ \a ->
vFinPoly $ \b ->
VFun $ \x ->
VFun $ \y -> vWord (1 + a + b) (mul 0 (fromVWord x) (fromVWord y) (1+b)))
VFun $ \y ->
case fromVWord x of
Left e -> vWordError (1 + a + b) e
Right i ->
case fromVWord y of
Left e -> vWordError (1 + a + b) e
Right j -> vWordValue (1 + a + b) (mul 0 i j (1+b)))
, ("pdiv" , vFinPoly $ \a ->
vFinPoly $ \b ->
VFun $ \x ->
VFun $ \y ->
vWord a (fst (divModPoly (fromVWord x) (fromInteger a) (fromVWord y) (fromInteger b))))
case fromVWord x of
Left e -> vWordError a e
Right i ->
case fromVWord y of
Left e -> vWordError a e
Right j -> vWordValue a (fst (divModPoly i (fromInteger a) j (fromInteger b))))
, ("pmod" , vFinPoly $ \a ->
vFinPoly $ \b ->
VFun $ \x ->
VFun $ \y ->
vWord b (snd (divModPoly (fromVWord x) (fromInteger a) (fromVWord y) (fromInteger b + 1))))
case fromVWord x of
Left e -> vWordError b e
Right i ->
case fromVWord y of
Left e -> vWordError b e
Right j -> vWordValue b (snd (divModPoly i (fromInteger a) j (fromInteger b + 1))))
{-
, ("random" , VPoly $ \a ->
wVFun $ \(bvVal -> x) -> return $ randomV a x)
@ -504,7 +561,7 @@ arithUnary op = go
TVBit ->
evalPanic "arithUnary" ["Bit not in class Arith"]
TVSeq w a
| isTBit a -> vWord w (op w (fromVWord val))
| isTBit a -> vWord w (op w <$> fromVWord val)
| otherwise -> VList (map (go a) (fromVList val))
TVStream a ->
VList (map (go a) (fromVList val))
@ -515,7 +572,7 @@ arithUnary op = go
TVRec fs ->
VRecord [ (f, go fty (lookupRecord f val)) | (f, fty) <- fs ]
arithBinary :: (Integer -> Integer -> Integer -> Integer)
arithBinary :: (Integer -> Integer -> Integer -> Either EvalError Integer)
-> TValue -> Value -> Value -> Value
arithBinary op = go
where
@ -525,7 +582,12 @@ arithBinary op = go
TVBit ->
evalPanic "arithBinary" ["Bit not in class Arith"]
TVSeq w a
| isTBit a -> vWord w (op w (fromVWord l) (fromVWord r))
| isTBit a -> case fromVWord l of
Left e -> vWordError w e
Right i ->
case fromVWord r of
Left e -> vWordError w e
Right j -> vWord w (op w i j)
| otherwise -> VList (zipWith (go a) (fromVList l) (fromVList r))
TVStream a ->
VList (zipWith (go a) (fromVList l) (fromVList r))
@ -536,18 +598,26 @@ arithBinary op = go
TVRec fs ->
VRecord [ (f, go fty (lookupRecord f l) (lookupRecord f r)) | (f, fty) <- fs ]
divWrap :: Integer -> Integer -> Either EvalError Integer
divWrap _ 0 = Left DivideByZero
divWrap x y = Right (x `div` y)
modWrap :: Integer -> Integer -> Either EvalError Integer
modWrap _ 0 = Left DivideByZero
modWrap x y = Right (x `mod` y)
-- Cmp -------------------------------------------------------------------------
-- | Process two elements based on their lexicographic ordering.
cmpOrder :: (Ordering -> Bool) -> TValue -> Value -> Value -> Value
cmpOrder p ty l r = VBit (p (lexCompare ty l r))
cmpOrder p ty l r = VBit (fmap p (lexCompare ty l r))
-- | Lexicographic ordering on two values.
lexCompare :: TValue -> Value -> Value -> Ordering
lexCompare :: TValue -> Value -> Value -> Either EvalError Ordering
lexCompare ty l r =
case ty of
TVBit ->
compare (fromVBit l) (fromVBit r)
compare <$> fromVBit l <*> fromVBit r
TVSeq _w ety ->
lexList (zipWith (lexCompare ety) (fromVList l) (fromVList r))
TVStream _ ->
@ -563,18 +633,19 @@ lexCompare ty l r =
in lexList (zipWith3 lexCompare tys ls rs)
-- TODO: should we make this strict in both arguments?
lexOrdering :: Ordering -> Ordering -> Ordering
lexOrdering LT _ = LT
lexOrdering EQ y = y
lexOrdering GT _ = GT
lexOrdering :: Either EvalError Ordering -> Either EvalError Ordering -> Either EvalError Ordering
lexOrdering (Left e) _ = Left e
lexOrdering (Right LT) _ = Right LT
lexOrdering (Right EQ) y = y
lexOrdering (Right GT) _ = Right GT
lexList :: [Ordering] -> Ordering
lexList = foldr lexOrdering EQ
lexList :: [Either EvalError Ordering] -> Either EvalError Ordering
lexList = foldr lexOrdering (Right EQ)
-- Logic -----------------------------------------------------------------------
logicNullary :: Bool -> TValue -> Value
logicNullary :: Either EvalError Bool -> TValue -> Value
logicNullary b = go
where
go TVBit = VBit b
@ -590,7 +661,7 @@ logicUnary op = go
go :: TValue -> Value -> Value
go ty val =
case ty of
TVBit -> VBit (op (fromVBit val))
TVBit -> VBit (fmap op (fromVBit val))
TVSeq _w ety -> VList (map (go ety) (fromVList val))
TVStream ety -> VList (map (go ety) (fromVList val))
TVTuple etys -> VTuple (zipWith go etys (fromVTuple val))
@ -603,7 +674,7 @@ logicBinary op = go
go :: TValue -> Value -> Value -> Value
go ty l r =
case ty of
TVBit -> VBit (op (fromVBit l) (fromVBit r))
TVBit -> VBit (liftA2 op (fromVBit l) (fromVBit r))
TVSeq _w ety -> VList (zipWith (go ety) (fromVList l) (fromVList r))
TVStream ety -> VList (zipWith (go ety) (fromVList l) (fromVList r))
TVTuple etys -> VTuple (zipWith3 go etys (fromVTuple l) (fromVTuple r))
@ -620,8 +691,10 @@ shiftV op =
VNumPoly $ \_b ->
VPoly $ \c ->
VFun $ \v ->
VFun $ \i ->
VList (op a (logicNullary False c) (fromVList v) (fromVWord i))
VFun $ \x ->
case fromVWord x of
Left e -> logicNullary (Left e) (tvSeq a c)
Right i -> VList (op a (logicNullary (Right False) c) (fromVList v) i)
shiftLV :: Nat' -> Value -> [Value] -> Integer -> [Value]
shiftLV w z vs i =
@ -639,10 +712,12 @@ rotateV :: (Integer -> [Value] -> Integer -> [Value]) -> Value
rotateV op =
vFinPoly $ \a ->
VNumPoly $ \_b ->
VPoly $ \_c ->
VPoly $ \c ->
VFun $ \v ->
VFun $ \i ->
VList (op a (fromVList v) (fromVWord i))
VFun $ \x ->
case fromVWord x of
Left e -> VList (genericReplicate a (logicNullary (Left e) c))
Right i -> VList (op a (fromVList v) i)
rotateLV :: Integer -> [Value] -> Integer -> [Value]
rotateLV 0 vs _ = vs
@ -683,41 +758,50 @@ indexPrimOne :: (Nat' -> TValue -> [Value] -> Integer -> Value) -> Value
indexPrimOne op =
VNumPoly $ \n ->
VPoly $ \a ->
VNumPoly $ \_i ->
VNumPoly $ \_w ->
VFun $ \l ->
VFun $ \r -> op n a (fromVList l) (fromVWord r)
VFun $ \r ->
case fromVWord r of
Left e -> logicNullary (Left e) a
Right i -> op n a (fromVList l) i
-- | Indexing operations that return many elements.
indexPrimMany :: (Nat' -> TValue -> [Value] -> Integer -> Value) -> Value
indexPrimMany op =
VNumPoly $ \n ->
VPoly $ \a ->
VNumPoly $ \_m ->
VNumPoly $ \_i ->
VNumPoly $ \_m ->
VNumPoly $ \_w ->
VFun $ \l ->
VFun $ \r -> VList [ op n a xs (fromVWord y) | let xs = fromVList l, y <- fromVList r ]
VFun $ \r -> VList [ case fromVWord y of
Left e -> logicNullary (Left e) a
Right i -> op n a xs i
| let xs = fromVList l, y <- fromVList r ]
indexFront :: Nat' -> TValue -> [Value] -> Integer -> Value
indexFront w a vs ix =
case w of
Nat n | n <= ix -> logicNullary (invalidIndex ix) a
Nat n | n <= ix -> logicNullary (Left (InvalidIndex ix)) a
_ -> genericIndex vs ix
indexBack :: Nat' -> TValue -> [Value] -> Integer -> Value
indexBack w a vs ix =
case w of
Nat n | n > ix -> genericIndex vs (n - ix - 1)
| otherwise -> logicNullary (invalidIndex ix) a
| otherwise -> logicNullary (Left (InvalidIndex ix)) a
Inf -> evalPanic "indexBack" ["unexpected infinite sequence"]
updatePrim :: (Nat' -> [Value] -> Integer -> Value -> [Value]) -> Value
updatePrim op =
VNumPoly $ \len ->
VPoly $ \_eltTy ->
VPoly $ \eltTy ->
VNumPoly $ \_idxLen ->
VFun $ \xs ->
VFun $ \idx ->
VFun $ \val -> VList (op len (fromVList xs) (fromVWord idx) val)
VFun $ \val ->
case fromVWord idx of
Left e -> logicNullary (Left e) (tvSeq len eltTy)
Right i -> VList (op len (fromVList xs) i val)
updateFront :: Nat' -> [Value] -> Integer -> Value -> [Value]
updateFront _ vs i x = updateAt vs i x
@ -740,7 +824,7 @@ ppValue val =
VRecord fs -> braces (sep (punctuate comma (map ppField fs)))
where ppField (f,r) = pp f <+> char '=' <+> ppValue r
VTuple vs -> parens (sep (punctuate comma (map ppValue vs)))
VBit b -> text (show b)
VBit b -> text (either show show b)
VList vs -> brackets (fsep (punctuate comma (map ppValue vs)))
VFun _ -> text "<function>"
VPoly _ -> text "<polymorphic value>"
@ -751,6 +835,3 @@ ppValue val =
evalPanic :: String -> [String] -> a
evalPanic cxt = panic ("[Reference Evaluator]" ++ cxt)
invalidIndex :: Integer -> Bool
invalidIndex i = X.throw (InvalidIndex i)