diff --git a/src/Cryptol/Symbolic/Prims.hs b/src/Cryptol/Symbolic/Prims.hs index 5ef9c62a..597a6e80 100644 --- a/src/Cryptol/Symbolic/Prims.hs +++ b/src/Cryptol/Symbolic/Prims.hs @@ -9,6 +9,7 @@ {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE PatternGuards #-} {-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE ViewPatterns #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE TypeSynonymInstances #-} @@ -19,7 +20,7 @@ import Data.List (genericDrop, genericReplicate, genericSplitAt, genericTake, so import Data.Ord (comparing) import Cryptol.Eval.Monad (Eval) -import Cryptol.Eval.Value (BitWord(..), EvalPrims(..), enumerateSeqMap) +import Cryptol.Eval.Value (BitWord(..), EvalPrims(..), enumerateSeqMap, SeqMap(..)) import Cryptol.Prims.Eval (binary, unary, tlamN, arithUnary, arithBinary, Binary, BinArith, logicBinary, logicUnary, zeroV) @@ -51,7 +52,9 @@ instance EvalPrims SBool SWord where evalPrim Decl { .. } = panic "Eval" [ "Unimplemented primitive", show dName ] - iteValue b x y = iteSValue b <$> x <*> y + iteValue b x y + | Just b' <- SBV.svAsBool b = if b' then x else y + | otherwise = iteSValue b <$> x <*> y -- See also Cryptol.Prims.Eval.primTable primTable :: Map.Map Ident Value @@ -79,43 +82,54 @@ primTable = Map.fromList $ map (\(n, v) -> (mkIdent (T.pack n), v)) , ("^" , binary (logicBinary SBV.svXOr SBV.svXOr)) , ("complement" , unary (logicUnary SBV.svNot SBV.svNot)) , ("zero" , tlam zeroV) - ] -{- , ("<<" , -- {m,n,a} (fin n) => [m] a -> [n] -> [m] a tlam $ \m -> - tlam $ \_ -> + tlam $ \n -> tlam $ \a -> - VFun $ \xs -> + VFun $ \xs -> return $ VFun $ \y -> - case xs of - VWord x -> VWord (SBV.svShiftLeft x (fromVWord y)) - _ -> + xs >>= \case + VWord x -> VWord . SBV.svShiftLeft x <$> (fromVWord "<<" =<< y) + x -> do + x' <- fromSeq x + let Nat len = numTValue n let shl :: Integer -> Value - shl i = + shl = case numTValue m of - Inf -> dropV i xs - Nat j | i >= j -> replicateV j a (zeroV a) - | otherwise -> catV (dropV i xs) (replicateV i a (zeroV a)) - - in selectV shl y) - + Inf -> \i -> VStream $ SeqMap $ \idx -> lookupSeqMap x' (i+idx) + Nat j -> \i -> VSeq j (isTBit a) $ SeqMap $ \idx -> + if i+idx >= j then + return $ zeroV a + else + lookupSeqMap x' (i+idx) + in selectV len shl =<< y) , (">>" , -- {m,n,a} (fin n) => [m] a -> [n] -> [m] a tlam $ \m -> - tlam $ \_ -> + tlam $ \n -> tlam $ \a -> - VFun $ \xs -> + VFun $ \xs -> return $ VFun $ \y -> - case xs of - VWord x -> VWord (SBV.svShiftRight x (fromVWord y)) - _ -> - let shr :: Integer -> Value - shr i = - case numTValue m of - Inf -> catV (replicateV i a (zeroV a)) xs - Nat j | i >= j -> replicateV j a (zeroV a) - | otherwise -> catV (replicateV i a (zeroV a)) (takeV (j - i) xs) - in selectV shr y) - + xs >>= \case + VWord x -> VWord . SBV.svShiftLeft x <$> (fromVWord "<<" =<< y) + x -> do + x' <- fromSeq x + let Nat len = numTValue n + let shr :: Integer -> Value + shr = + case numTValue m of + Inf -> \i -> VStream $ SeqMap $ \idx -> + if idx-i < 0 then + return $ zeroV a + else + lookupSeqMap x' (idx-i) + Nat j -> \i -> VSeq j (isTBit a) $ SeqMap $ \idx -> + if idx-i < 0 then + return $ zeroV a + else + lookupSeqMap x' (idx-i) + in selectV len shr =<< y) + ] +{- , ("<<<" , -- {m,n,a} (fin m, fin n) => [m] a -> [n] -> [m] a tlam $ \m -> tlam $ \_ -> @@ -279,19 +293,23 @@ primTable = Map.fromList $ map (\(n, v) -> (mkIdent (T.pack n), v)) , ("random" , panic "Cryptol.Symbolic.Prims.evalECon" [ "can't symbolically evaluae ECRandom" ]) ] +-} +selectV :: Integer -> (Integer -> Value) -> Value -> Eval Value +selectV _len f (VWord v) | Just idx <- SBV.svAsInteger v = return $ f idx -selectV :: (Integer -> Value) -> Value -> Value -selectV f v = sel 0 bits +selectV len f v = sel 0 =<< bits where - bits = map fromVBit (fromSeq v) -- index bits in big-endian order + bits = enumerateSeqMap len <$> fromSeq v -- index bits in big-endian order - sel :: Integer -> [SBool] -> Value - sel offset [] = f offset - sel offset (b : bs) = iteValue b m1 m2 + sel :: Integer -> [Eval Value] -> Eval Value + sel offset [] = return $ f offset + sel offset (b : bs) = do b' <- fromVBit <$> b + iteValue b' m1 m2 where m1 = sel (offset + 2 ^ length bs) bs m2 = sel offset bs +{- asWordList :: [Value] -> Maybe [SWord] asWordList = go id where go :: ([SWord] -> [SWord]) -> [Value] -> Maybe [SWord]