diff --git a/src/Cryptol/Eval/Reference.hs b/src/Cryptol/Eval/Reference.hs index c1df1887..1b2b0c84 100644 --- a/src/Cryptol/Eval/Reference.hs +++ b/src/Cryptol/Eval/Reference.hs @@ -12,7 +12,7 @@ module Cryptol.Eval.Reference where -import qualified Control.Exception as X (throw) +import Control.Applicative (liftA2) import Control.Monad (foldM) import Data.Bits import Data.List @@ -26,8 +26,8 @@ import Cryptol.ModuleSystem.Name (asPrim) import Cryptol.TypeCheck.Solver.InfNat (Nat'(..)) import Cryptol.TypeCheck.AST import Cryptol.Eval.Monad (EvalError(..)) -import Cryptol.Eval.Type (TypeEnv, TValue(..), isTBit, evalValType, evalNumType) -import Cryptol.Prims.Eval (divWrap, modWrap, lg2, divModPoly) +import Cryptol.Eval.Type (TypeEnv, TValue(..), isTBit, evalValType, evalNumType, tvSeq) +import Cryptol.Prims.Eval (lg2, divModPoly) import Cryptol.Utils.Ident (Ident, mkIdent) import Cryptol.Utils.Panic (panic) import Cryptol.Utils.PP @@ -53,16 +53,16 @@ evaluate expr modEnv = return (Right (evalExpr env expr, modEnv), []) -- of a @VBit@ constructor. All other @Value@ and list constructors -- should evaluate without error. data Value - = VRecord [(Ident, Value)] -- ^ @ { .. } @ - | VTuple [Value] -- ^ @ ( .. ) @ - | VBit Bool -- ^ @ Bit @ - | VList [Value] -- ^ @ [n]a @ (either finite or infinite) - | VFun (Value -> Value) -- ^ functions - | VPoly (TValue -> Value) -- ^ polymorphic values (kind *) - | VNumPoly (Nat' -> Value) -- ^ polymorphic values (kind #) + = VRecord [(Ident, Value)] -- ^ @ { .. } @ + | VTuple [Value] -- ^ @ ( .. ) @ + | VBit (Either EvalError Bool) -- ^ @ Bit @ + | VList [Value] -- ^ @ [n]a @ (either finite or infinite) + | VFun (Value -> Value) -- ^ functions + | VPoly (TValue -> Value) -- ^ polymorphic values (kind *) + | VNumPoly (Nat' -> Value) -- ^ polymorphic values (kind #) -- | Destructor for @VBit@. -fromVBit :: Value -> Bool +fromVBit :: Value -> Either EvalError Bool fromVBit (VBit b) = b fromVBit _ = evalPanic "fromVBit" ["Expected a bit"] @@ -81,6 +81,16 @@ fromVFun :: Value -> (Value -> Value) fromVFun (VFun f) = f fromVFun _ = evalPanic "fromVFun" ["Expected a function"] +-- | Destructor for @VPoly@. +fromVPoly :: Value -> (TValue -> Value) +fromVPoly (VPoly f) = f +fromVPoly _ = evalPanic "fromVPoly" ["Expected a polymorphic value"] + +-- | Destructor for @VNumPoly@. +fromVNumPoly :: Value -> (Nat' -> Value) +fromVNumPoly (VNumPoly f) = f +fromVNumPoly _ = evalPanic "fromVNumPoly" ["Expected a polymorphic value"] + -- | Destructor for @VRecord@. fromVRecord :: Value -> [(Ident, Value)] fromVRecord (VRecord fs) = fs @@ -105,12 +115,20 @@ integerToBits w x = go [] w x go bs n a = go (odd a : bs) (n - 1) $! (a `div` 2) -- | Convert a value from a big-endian binary format to an integer. -fromVWord :: Value -> Integer -fromVWord v = bitsToInteger (map fromVBit (fromVList v)) +fromVWord :: Value -> Either EvalError Integer +fromVWord v = fmap bitsToInteger (mapM fromVBit (fromVList v)) -- | Convert an integer to big-endian binary value of the specified width. -vWord :: Integer -> Integer -> Value -vWord w x = VList (map VBit (integerToBits w x)) +vWordValue :: Integer -> Integer -> Value +vWordValue w x = VList (map (VBit . Right) (integerToBits w x)) + +-- | Create a run-time error value of bitvector type. +vWordError :: Integer -> EvalError -> Value +vWordError w e = VList (genericReplicate w (VBit (Left e))) + +vWord :: Integer -> Either EvalError Integer -> Value +vWord w (Left e) = vWordError w e +vWord w (Right x) = vWordValue w x vFinPoly :: (Integer -> Value) -> Value vFinPoly f = VNumPoly g @@ -118,6 +136,25 @@ vFinPoly f = VNumPoly g g (Nat n) = f n g Inf = evalPanic "vFinPoly" ["Expected finite numeric type"] + +-- Conditionals ---------------------------------------------------------------- + +condBit :: Either e Bool -> Either e a -> Either e a -> Either e a +condBit (Left e) _ _ = Left e +condBit (Right b) x y = if b then x else y + +condValue :: Either EvalError Bool -> Value -> Value -> Value +condValue c l r = + case l of + VRecord fs -> VRecord [ (f, condValue c v (lookupRecord f r)) | (f, v) <- fs ] + VTuple vs -> VTuple (zipWith (condValue c) vs (fromVList r)) + VBit b -> VBit (condBit c b (fromVBit r)) + VList vs -> VList (zipWith (condValue c) vs (fromVList r)) + VFun f -> VFun (\v -> condValue c (f v) (fromVFun r v)) + VPoly f -> VPoly (\t -> condValue c (f t) (fromVPoly r t)) + VNumPoly f -> VNumPoly (\n -> condValue c (f n) (fromVNumPoly r n)) + + -- Environments ---------------------------------------------------------------- -- | Evaluation environment. @@ -158,11 +195,7 @@ evalExpr env expr = ETuple es -> VTuple [ evalExpr env e | e <- es ] ERec fields -> VRecord [ (f, evalExpr env e) | (f, e) <- fields ] ESel e sel -> evalSel (evalExpr env e) sel - EIf c t f -> evalExpr env (if fromVBit (evalExpr env c) then t else f) - -- FIXME: this produces an invalid result if evaluation of the - -- condition yields a run-time error or fails to terminate. We - -- should use a zip-like function to push the conditionals down - -- into the leaf bits. + EIf c t f -> condValue (fromVBit (evalExpr env c)) (evalExpr env t) (evalExpr env f) EComp _n _ty h gs -> evalComp env h gs @@ -222,7 +255,7 @@ evalSel val sel = case v of VList vs -> vs !! n _ -> evalPanic "evalSel" - [ "Unexpected value in list selection" ] + ["Unexpected value in list selection."] -- List Comprehensions --------------------------------------------------------- @@ -339,9 +372,9 @@ evalPrim n primTable :: Map.Map Ident Value primTable = Map.fromList $ map (\(n, v) -> (mkIdent (T.pack n), v)) - [ ("+" , binary (arithBinary (const (+)))) - , ("-" , binary (arithBinary (const (-)))) - , ("*" , binary (arithBinary (const (*)))) + [ ("+" , binary (arithBinary (\_ x y -> Right (x + y)))) + , ("-" , binary (arithBinary (\_ x y -> Right (x - y)))) + , ("*" , binary (arithBinary (\_ x y -> Right (x * y)))) , ("/" , binary (arithBinary (const divWrap))) , ("%" , binary (arithBinary (const modWrap))) -- , ("^^" , binary (arithBinary modExp)) @@ -361,12 +394,12 @@ primTable = Map.fromList $ map (\(n, v) -> (mkIdent (T.pack n), v)) , (">>" , shiftV shiftRV) , ("<<<" , rotateV rotateLV) , (">>>" , rotateV rotateRV) - , ("True" , VBit True) - , ("False" , VBit False) + , ("True" , VBit (Right True)) + , ("False" , VBit (Right False)) , ("demote" , vFinPoly $ \val -> vFinPoly $ \bits -> - vWord bits val) + vWordValue bits val) , ("#" , VNumPoly $ \_front -> VNumPoly $ \_back -> @@ -382,7 +415,7 @@ primTable = Map.fromList $ map (\(n, v) -> (mkIdent (T.pack n), v)) , ("update" , updatePrim updateFront) , ("updateEnd" , updatePrim updateBack) - , ("zero" , VPoly (logicNullary False)) + , ("zero" , VPoly (logicNullary (Right False))) , ("join" , VNumPoly $ \_parts -> VNumPoly $ \_each -> @@ -407,32 +440,40 @@ primTable = Map.fromList $ map (\(n, v) -> (mkIdent (T.pack n), v)) vFinPoly $ \next -> vFinPoly $ \bits -> vFinPoly $ \len -> - VList (map (vWord bits) (genericTake len [first, next ..]))) + VList (map (vWordValue bits) (genericTake len [first, next ..]))) , ("fromTo" , vFinPoly $ \first -> vFinPoly $ \lst -> vFinPoly $ \bits -> - VList (map (vWord bits) [first .. lst])) + VList (map (vWordValue bits) [first .. lst])) , ("fromThenTo" , vFinPoly $ \first -> vFinPoly $ \next -> vFinPoly $ \_lst -> vFinPoly $ \bits -> vFinPoly $ \len -> - VList (map (vWord bits) (genericTake len [first, next ..]))) + VList (map (vWordValue bits) (genericTake len [first, next ..]))) , ("infFrom" , vFinPoly $ \bits -> VFun $ \first -> - VList (map (vWord bits) [fromVWord first ..])) + case fromVWord first of + Left e -> VList (repeat (vWordError bits e)) + Right i -> VList (map (vWordValue bits) [i ..])) , ("infFromThen", vFinPoly $ \bits -> VFun $ \first -> VFun $ \next -> - VList (map (vWord bits) [fromVWord first, fromVWord next ..])) + case fromVWord first of + Left e -> VList (repeat (vWordError bits e)) + Right i -> + case fromVWord next of + Left e -> VList (repeat (vWordError bits e)) + Right j -> VList (map (vWordValue bits) [i, j ..])) , ("error" , VPoly $ \a -> VNumPoly $ \_ -> - VFun $ \_s -> logicNullary (error "error") a) + VFun $ \_s -> logicNullary (Left (UserError "error")) a) + -- TODO: obtain error string from argument s , ("reverse" , VNumPoly $ \_a -> VPoly $ \_b -> @@ -450,18 +491,34 @@ primTable = Map.fromList $ map (\(n, v) -> (mkIdent (T.pack n), v)) in vFinPoly $ \a -> vFinPoly $ \b -> VFun $ \x -> - VFun $ \y -> vWord (1 + a + b) (mul 0 (fromVWord x) (fromVWord y) (1+b))) + VFun $ \y -> + case fromVWord x of + Left e -> vWordError (1 + a + b) e + Right i -> + case fromVWord y of + Left e -> vWordError (1 + a + b) e + Right j -> vWordValue (1 + a + b) (mul 0 i j (1+b))) , ("pdiv" , vFinPoly $ \a -> vFinPoly $ \b -> VFun $ \x -> VFun $ \y -> - vWord a (fst (divModPoly (fromVWord x) (fromInteger a) (fromVWord y) (fromInteger b)))) + case fromVWord x of + Left e -> vWordError a e + Right i -> + case fromVWord y of + Left e -> vWordError a e + Right j -> vWordValue a (fst (divModPoly i (fromInteger a) j (fromInteger b)))) , ("pmod" , vFinPoly $ \a -> vFinPoly $ \b -> VFun $ \x -> VFun $ \y -> - vWord b (snd (divModPoly (fromVWord x) (fromInteger a) (fromVWord y) (fromInteger b + 1)))) + case fromVWord x of + Left e -> vWordError b e + Right i -> + case fromVWord y of + Left e -> vWordError b e + Right j -> vWordValue b (snd (divModPoly i (fromInteger a) j (fromInteger b + 1)))) {- , ("random" , VPoly $ \a -> wVFun $ \(bvVal -> x) -> return $ randomV a x) @@ -504,7 +561,7 @@ arithUnary op = go TVBit -> evalPanic "arithUnary" ["Bit not in class Arith"] TVSeq w a - | isTBit a -> vWord w (op w (fromVWord val)) + | isTBit a -> vWord w (op w <$> fromVWord val) | otherwise -> VList (map (go a) (fromVList val)) TVStream a -> VList (map (go a) (fromVList val)) @@ -515,7 +572,7 @@ arithUnary op = go TVRec fs -> VRecord [ (f, go fty (lookupRecord f val)) | (f, fty) <- fs ] -arithBinary :: (Integer -> Integer -> Integer -> Integer) +arithBinary :: (Integer -> Integer -> Integer -> Either EvalError Integer) -> TValue -> Value -> Value -> Value arithBinary op = go where @@ -525,7 +582,12 @@ arithBinary op = go TVBit -> evalPanic "arithBinary" ["Bit not in class Arith"] TVSeq w a - | isTBit a -> vWord w (op w (fromVWord l) (fromVWord r)) + | isTBit a -> case fromVWord l of + Left e -> vWordError w e + Right i -> + case fromVWord r of + Left e -> vWordError w e + Right j -> vWord w (op w i j) | otherwise -> VList (zipWith (go a) (fromVList l) (fromVList r)) TVStream a -> VList (zipWith (go a) (fromVList l) (fromVList r)) @@ -536,18 +598,26 @@ arithBinary op = go TVRec fs -> VRecord [ (f, go fty (lookupRecord f l) (lookupRecord f r)) | (f, fty) <- fs ] +divWrap :: Integer -> Integer -> Either EvalError Integer +divWrap _ 0 = Left DivideByZero +divWrap x y = Right (x `div` y) + +modWrap :: Integer -> Integer -> Either EvalError Integer +modWrap _ 0 = Left DivideByZero +modWrap x y = Right (x `mod` y) + -- Cmp ------------------------------------------------------------------------- -- | Process two elements based on their lexicographic ordering. cmpOrder :: (Ordering -> Bool) -> TValue -> Value -> Value -> Value -cmpOrder p ty l r = VBit (p (lexCompare ty l r)) +cmpOrder p ty l r = VBit (fmap p (lexCompare ty l r)) -- | Lexicographic ordering on two values. -lexCompare :: TValue -> Value -> Value -> Ordering +lexCompare :: TValue -> Value -> Value -> Either EvalError Ordering lexCompare ty l r = case ty of TVBit -> - compare (fromVBit l) (fromVBit r) + compare <$> fromVBit l <*> fromVBit r TVSeq _w ety -> lexList (zipWith (lexCompare ety) (fromVList l) (fromVList r)) TVStream _ -> @@ -563,18 +633,19 @@ lexCompare ty l r = in lexList (zipWith3 lexCompare tys ls rs) -- TODO: should we make this strict in both arguments? -lexOrdering :: Ordering -> Ordering -> Ordering -lexOrdering LT _ = LT -lexOrdering EQ y = y -lexOrdering GT _ = GT +lexOrdering :: Either EvalError Ordering -> Either EvalError Ordering -> Either EvalError Ordering +lexOrdering (Left e) _ = Left e +lexOrdering (Right LT) _ = Right LT +lexOrdering (Right EQ) y = y +lexOrdering (Right GT) _ = Right GT -lexList :: [Ordering] -> Ordering -lexList = foldr lexOrdering EQ +lexList :: [Either EvalError Ordering] -> Either EvalError Ordering +lexList = foldr lexOrdering (Right EQ) -- Logic ----------------------------------------------------------------------- -logicNullary :: Bool -> TValue -> Value +logicNullary :: Either EvalError Bool -> TValue -> Value logicNullary b = go where go TVBit = VBit b @@ -590,7 +661,7 @@ logicUnary op = go go :: TValue -> Value -> Value go ty val = case ty of - TVBit -> VBit (op (fromVBit val)) + TVBit -> VBit (fmap op (fromVBit val)) TVSeq _w ety -> VList (map (go ety) (fromVList val)) TVStream ety -> VList (map (go ety) (fromVList val)) TVTuple etys -> VTuple (zipWith go etys (fromVTuple val)) @@ -603,7 +674,7 @@ logicBinary op = go go :: TValue -> Value -> Value -> Value go ty l r = case ty of - TVBit -> VBit (op (fromVBit l) (fromVBit r)) + TVBit -> VBit (liftA2 op (fromVBit l) (fromVBit r)) TVSeq _w ety -> VList (zipWith (go ety) (fromVList l) (fromVList r)) TVStream ety -> VList (zipWith (go ety) (fromVList l) (fromVList r)) TVTuple etys -> VTuple (zipWith3 go etys (fromVTuple l) (fromVTuple r)) @@ -620,8 +691,10 @@ shiftV op = VNumPoly $ \_b -> VPoly $ \c -> VFun $ \v -> - VFun $ \i -> - VList (op a (logicNullary False c) (fromVList v) (fromVWord i)) + VFun $ \x -> + case fromVWord x of + Left e -> logicNullary (Left e) (tvSeq a c) + Right i -> VList (op a (logicNullary (Right False) c) (fromVList v) i) shiftLV :: Nat' -> Value -> [Value] -> Integer -> [Value] shiftLV w z vs i = @@ -639,10 +712,12 @@ rotateV :: (Integer -> [Value] -> Integer -> [Value]) -> Value rotateV op = vFinPoly $ \a -> VNumPoly $ \_b -> - VPoly $ \_c -> + VPoly $ \c -> VFun $ \v -> - VFun $ \i -> - VList (op a (fromVList v) (fromVWord i)) + VFun $ \x -> + case fromVWord x of + Left e -> VList (genericReplicate a (logicNullary (Left e) c)) + Right i -> VList (op a (fromVList v) i) rotateLV :: Integer -> [Value] -> Integer -> [Value] rotateLV 0 vs _ = vs @@ -683,41 +758,50 @@ indexPrimOne :: (Nat' -> TValue -> [Value] -> Integer -> Value) -> Value indexPrimOne op = VNumPoly $ \n -> VPoly $ \a -> - VNumPoly $ \_i -> + VNumPoly $ \_w -> VFun $ \l -> - VFun $ \r -> op n a (fromVList l) (fromVWord r) + VFun $ \r -> + case fromVWord r of + Left e -> logicNullary (Left e) a + Right i -> op n a (fromVList l) i -- | Indexing operations that return many elements. indexPrimMany :: (Nat' -> TValue -> [Value] -> Integer -> Value) -> Value indexPrimMany op = VNumPoly $ \n -> VPoly $ \a -> - VNumPoly $ \_m -> - VNumPoly $ \_i -> + VNumPoly $ \_m -> + VNumPoly $ \_w -> VFun $ \l -> - VFun $ \r -> VList [ op n a xs (fromVWord y) | let xs = fromVList l, y <- fromVList r ] + VFun $ \r -> VList [ case fromVWord y of + Left e -> logicNullary (Left e) a + Right i -> op n a xs i + | let xs = fromVList l, y <- fromVList r ] indexFront :: Nat' -> TValue -> [Value] -> Integer -> Value indexFront w a vs ix = case w of - Nat n | n <= ix -> logicNullary (invalidIndex ix) a + Nat n | n <= ix -> logicNullary (Left (InvalidIndex ix)) a _ -> genericIndex vs ix indexBack :: Nat' -> TValue -> [Value] -> Integer -> Value indexBack w a vs ix = case w of Nat n | n > ix -> genericIndex vs (n - ix - 1) - | otherwise -> logicNullary (invalidIndex ix) a + | otherwise -> logicNullary (Left (InvalidIndex ix)) a Inf -> evalPanic "indexBack" ["unexpected infinite sequence"] updatePrim :: (Nat' -> [Value] -> Integer -> Value -> [Value]) -> Value updatePrim op = VNumPoly $ \len -> - VPoly $ \_eltTy -> + VPoly $ \eltTy -> VNumPoly $ \_idxLen -> VFun $ \xs -> VFun $ \idx -> - VFun $ \val -> VList (op len (fromVList xs) (fromVWord idx) val) + VFun $ \val -> + case fromVWord idx of + Left e -> logicNullary (Left e) (tvSeq len eltTy) + Right i -> VList (op len (fromVList xs) i val) updateFront :: Nat' -> [Value] -> Integer -> Value -> [Value] updateFront _ vs i x = updateAt vs i x @@ -740,7 +824,7 @@ ppValue val = VRecord fs -> braces (sep (punctuate comma (map ppField fs))) where ppField (f,r) = pp f <+> char '=' <+> ppValue r VTuple vs -> parens (sep (punctuate comma (map ppValue vs))) - VBit b -> text (show b) + VBit b -> text (either show show b) VList vs -> brackets (fsep (punctuate comma (map ppValue vs))) VFun _ -> text "" VPoly _ -> text "" @@ -751,6 +835,3 @@ ppValue val = evalPanic :: String -> [String] -> a evalPanic cxt = panic ("[Reference Evaluator]" ++ cxt) - -invalidIndex :: Integer -> Bool -invalidIndex i = X.throw (InvalidIndex i)