Adapt to use new Data.SBV.Dynamic API in SBV-4.3

This commit is contained in:
Brian Huffman 2015-03-30 10:03:59 -07:00 committed by Adam C. Foltzer
parent 0f04f0753d
commit 8ddec0a2bc
4 changed files with 160 additions and 293 deletions

View File

@ -23,6 +23,8 @@ import Data.Traversable (traverse)
import qualified Control.Exception as X
import qualified Data.SBV as SBV
import qualified Data.SBV.Dynamic as SBV
import qualified Data.SBV.Internals as SBV
import qualified Cryptol.ModuleSystem as M
import qualified Cryptol.ModuleSystem.Env as M
@ -79,6 +81,12 @@ allSatSMTResults (SBV.AllSatResult (_, rs)) = rs
thmSMTResults :: SBV.ThmResult -> [SBV.SMTResult]
thmSMTResults (SBV.ThmResult r) = [r]
allSatWithAny :: [SBV.SMTConfig] -> SBV.Symbolic SBV.SVal -> IO (SBV.Solver, SBV.AllSatResult)
allSatWithAny cfgs s = SBV.allSatWithAny cfgs (fmap (SBV.SBV :: SBV.SVal -> SBV.SBV Bool) s)
proveWithAny :: [SBV.SMTConfig] -> SBV.Symbolic SBV.SVal -> IO (SBV.Solver, SBV.ThmResult)
proveWithAny cfgs s = SBV.proveWithAny cfgs (fmap (SBV.SBV :: SBV.SVal -> SBV.SBV Bool) s)
satProve :: Bool
-> Maybe Int -- ^ satNum
-> (String, Bool, Bool)
@ -102,8 +110,8 @@ satProve isSat mSatNum (proverName, useSolverIte, verbose) edecls mfile (expr, s
when verbose $ liftIO $
putStrLn $ "Got result from " ++ show firstProver
return (tag res)
let runFn | isSat = runProver SBV.allSatWithAny allSatSMTResults
| otherwise = runProver SBV.proveWithAny thmSMTResults
let runFn | isSat = runProver allSatWithAny allSatSMTResults
| otherwise = runProver proveWithAny thmSMTResults
case predArgTypes schema of
Left msg -> return (Right (ProverError msg, modEnv), [])
Right ts -> do when verbose $ putStrLn "Simulating..."
@ -144,6 +152,9 @@ satProve isSat mSatNum (proverName, useSolverIte, verbose) edecls mfile (expr, s
[ "attempted to evaluate bogus boolean for pretty-printing" ]
return (Right (esatexprs, modEnv), [])
compileToSMTLib :: Bool -> Bool -> SBV.Symbolic SBV.SVal -> IO String
compileToSMTLib a b s = SBV.compileToSMTLib a b (fmap (SBV.SBV :: SBV.SVal -> SBV.SBV Bool) s)
satProveOffline :: Bool
-> Bool
-> Bool
@ -164,7 +175,7 @@ satProveOffline isSat useIte vrb edecls mfile (expr, schema) =
let v = evalExpr env expr
let satWord | isSat = "satisfiability"
| otherwise = "validity"
txt <- SBV.compileToSMTLib True isSat $ do
txt <- compileToSMTLib True isSat $ do
args <- mapM tyFn ts
b <- return $! fromVBit (foldl fromVFun v args)
liftIO $ putStrLn $
@ -265,8 +276,8 @@ predArgTypes schema@(Forall ts ps ty)
forallFinType :: FinType -> SBV.Symbolic Value
forallFinType ty =
case ty of
FTBit -> VBit <$> SBV.forall_
FTSeq 0 FTBit -> return $ VWord (SBV.literal (bv 0 0))
FTBit -> VBit <$> forallSBool_
FTSeq 0 FTBit -> return $ VWord (literalSWord 0 0)
FTSeq n FTBit -> VWord <$> (forallBV_ n)
FTSeq n t -> VSeq False <$> replicateM n (forallFinType t)
FTTuple ts -> VTuple <$> mapM forallFinType ts
@ -275,8 +286,8 @@ forallFinType ty =
existsFinType :: FinType -> SBV.Symbolic Value
existsFinType ty =
case ty of
FTBit -> VBit <$> SBV.exists_
FTSeq 0 FTBit -> return $ VWord (SBV.literal (bv 0 0))
FTBit -> VBit <$> existsSBool_
FTSeq 0 FTBit -> return $ VWord (literalSWord 0 0)
FTSeq n FTBit -> VWord <$> existsBV_ n
FTSeq n t -> VSeq False <$> replicateM n (existsFinType t)
FTTuple ts -> VTuple <$> mapM existsFinType ts
@ -324,6 +335,9 @@ lookupType p env = Map.lookup p (envTypes env)
-- Expressions -----------------------------------------------------------------
svSBranch :: SBool -> Value -> Value -> Value
svSBranch t x y = SBV.sBranch (SBV.SBV t) x y
evalExpr :: Env -> Expr -> Value
evalExpr env expr =
case expr of
@ -333,7 +347,7 @@ evalExpr env expr =
ERec fields -> VRecord [ (f, eval e) | (f, e) <- fields ]
ESel e sel -> evalSel sel (eval e)
EIf b e1 e2 -> evalIf (fromVBit (eval b)) (eval e1) (eval e2)
where evalIf = if envIteSolver env then SBV.sBranch else SBV.ite
where evalIf = if envIteSolver env then svSBranch else iteValue
EComp ty e mss -> evalComp env (evalType env ty) e mss
EVar n -> case lookupVar n env of
Just x -> x
@ -377,7 +391,7 @@ evalSel sel v =
_ -> panic "Cryptol.Symbolic.evalSel" [ "Record selector applied to non-record" ]
ListSel n _ -> case v of
VWord s -> VBit (SBV.sbvTestBit s n)
VWord s -> VBit (SBV.svTestBit s n)
_ -> fromSeq v !! n -- 0-based indexing
-- Declarations ----------------------------------------------------------------

View File

@ -13,193 +13,36 @@
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE ViewPatterns #-}
module Cryptol.Symbolic.BitVector where
module Cryptol.Symbolic.BitVector
( SBool, SWord
, literalSWord
, fromBitsLE
, forallBV_, existsBV_
, forallSBool_, existsSBool_
) where
import Data.Bits
import Data.List (foldl')
import Control.Monad (replicateM)
import System.Random
import Data.SBV.Bridge.Yices
import Data.SBV.Internals
import Data.SBV.Dynamic
import Cryptol.Utils.Panic
type SBool = SVal
type SWord = SVal
-- BitVector type --------------------------------------------------------------
data BitVector = BV { signedcxt :: Bool, width :: !Int, val :: !Integer }
deriving (Eq, Ord, Show)
-- ^ Invariant: BV w x requires that 0 <= w and 0 <= x < 2^w.
bitMask :: Int -> Integer
bitMask w = bit w - 1
-- | Smart constructor for bitvectors.
bv :: Int -> Integer -> BitVector
bv = sbv False
sbv :: Bool -> Int -> Integer -> BitVector
sbv b w x = BV b w (x .&. bitMask w)
unsigned :: Int -> Integer -> Integer
unsigned w x = x + bit w
signed :: Int -> Integer -> Integer
signed w x
| w > 0 && testBit x (w - 1) = x - bit w
| otherwise = x
same :: Int -> Int -> Int
same m n | m == n = m
| otherwise = panic "Cryptol.Symbolic.BitVector.same"
[ "BitVector size mismatch: " ++ show (m, n) ]
instance SignCast SWord SWord where
signCast (SBV (KBounded _ w) (Left (cwVal -> (CWInteger x)))) =
SBV k (Left (CW k (CWInteger (signed w x)))) where
k = KBounded True w
signCast x@(SBV (KBounded _ w) _) = SBV k (Right (cache y)) where
k = KBounded True w
y st = do xsw <- sbvToSW st x
newExpr st k (SBVApp (Extract (w - 1) 0) [xsw])
signCast _ = panic "Cryptol.Symbolic.BitVector"
[ "signCast called on non-bitvector value" ]
unsignCast (SBV (KBounded _ w) (Left (cwVal -> (CWInteger x)))) =
SBV k (Left (CW k (CWInteger (unsigned w x)))) where
k = KBounded False w
unsignCast x@(SBV (KBounded _ w) _) = SBV k (Right (cache y)) where
k = KBounded False w
y st = do xsw <- sbvToSW st x
newExpr st k (SBVApp (Extract (w - 1) 0) [xsw])
unsignCast _ = panic "Cryptol.Symbolic.BitVector"
[ "unsignCast called on non-bitvector value" ]
instance Num BitVector where
fromInteger n = panic "Cryptol.Symbolic.BitVector"
[ "fromInteger " ++ show n ++ " :: BitVector" ]
BV s m x + BV _ n y = sbv s (same m n) (x + y)
BV s m x - BV _ n y = sbv s (same m n) (x - y)
BV s m x * BV _ n y = sbv s (same m n) (x * y)
negate (BV s m x) = sbv s m (- x)
abs = id
signum (BV s m _) = sbv s m 1
instance Bits BitVector where
BV s m x .&. BV _ n y = BV s (same m n) (x .&. y)
BV s m x .|. BV _ n y = BV s (same m n) (x .|. y)
BV s m x `xor` BV _ n y = BV s (same m n) (x `xor` y)
complement (BV s m x) = BV s m (x `xor` bitMask m)
shift (BV s m x) i = sbv s m (shift x i)
rotate (BV s m x) i = sbv s m (shift x j .|. shift x (j - m))
where j = i `mod` m
bit _i = panic "Cryptol.Symbolic.BitVector"
[ "bit: can't determine width" ]
setBit (BV s m x) i = BV s m (setBit x i)
clearBit (BV s m x) i = BV s m (clearBit x i)
complementBit (BV s m x) i = BV s m (complementBit x i)
testBit (BV _ _ x) i = testBit x i
bitSize x
| Just m <- bitSizeMaybe x = m
| otherwise = panic "Cryptol.Symbolic.BitVector"
[ "bitSize should be total for BitVector" ]
bitSizeMaybe (BV _ m _) = Just m
isSigned (BV s _ _) = s
popCount (BV _ _ x) = popCount x
instance FiniteBits BitVector where
finiteBitSize (BV _ m _) = m
instance FiniteBits (SBV BitVector) where
finiteBitSize (SBV (KBounded _ w) _) = w
finiteBitSize _ = panic "Cryptol.Symbolic.BitVector"
[ "finiteBitSize called on non-bitvector value" ]
--------------------------------------------------------------------------------
-- SBV class instances
type SWord = SBV BitVector
instance HasKind BitVector where
kindOf (BV s w _) = KBounded s w
instance SymWord BitVector where
literal (BV s w x) = SBV k (Left (mkConstCW k x))
where k = KBounded s w
fromCW c@(CW (KBounded s w) _) = BV s w (fromCW c)
fromCW c = panic "Cryptol.Symbolic.BitVector"
[ "fromCW: Unsupported non-integral value: " ++ show c ]
mkSymWord _ _ = panic "Cryptol.Symbolic.BitVector"
[ "mkSymWord unimplemented for type BitVector" ]
instance SIntegral BitVector where
instance FromBits (SBV BitVector) where
fromBitsLE bs = foldl' f (literalSWord 0 0) bs
where f w b = cat (ite b (literalSWord 1 1) (literalSWord 1 0)) w
instance SDivisible BitVector where
sQuotRem (BV _ m x) (BV _ n y) = (BV False w q, BV False w r)
where (q, r) = quotRem x y
w = same m n
sDivMod (BV _ m x) (BV _ n y) = (BV False w q, BV False w r)
where (q, r) = divMod x y
w = same m n
instance SDivisible (SBV BitVector) where
sQuotRem = liftQRem
sDivMod = liftDMod
extract :: Int -> Int -> SWord -> SWord
extract i j x@(SBV (KBounded s _) _) =
case x of
_ | i < j -> SBV k (Left (CW k (CWInteger 0)))
SBV _ (Left cw) ->
case cw of
CW _ (CWInteger v) -> SBV k (Left (normCW (CW k (CWInteger (v `shiftR` j)))))
_ -> panic "Cryptol.Symbolic.BitVector.extract" [ "non-integer concrete word" ]
_ -> SBV k (Right (cache y))
where y st = do sw <- sbvToSW st x
newExpr st k (SBVApp (Extract i j) [sw])
where
k = KBounded s (i - j + 1)
extract _ _ _ = panic "Cryptol.Symbolic.BitVector.extract" [ "non-bitvector value" ]
cat :: SWord -> SWord -> SWord
cat x y | finiteBitSize x == 0 = y
| finiteBitSize y == 0 = x
cat x@(SBV _ (Left a)) y@(SBV _ (Left b)) =
case (a, b) of
(CW _ (CWInteger m), CW _ (CWInteger n)) ->
SBV k (Left (CW k (CWInteger ((m `shiftL` (finiteBitSize y) .|. n)))))
_ -> panic "Cryptol.Symbolic.BitVector.cat" [ "non-integer concrete word" ]
where k = KBounded False (finiteBitSize x + finiteBitSize y)
cat x y = SBV k (Right (cache z))
where k = KBounded False (finiteBitSize x + finiteBitSize y)
z st = do xsw <- sbvToSW st x
ysw <- sbvToSW st y
newExpr st k (SBVApp Join [xsw, ysw])
fromBitsLE :: [SVal] -> SWord
fromBitsLE bs = foldl' f (literalSWord 0 0) bs
where f w b = svJoin (svToWord1 b) w
literalSWord :: Int -> Integer -> SWord
literalSWord w i = genLiteral (KBounded False w) i
literalSWord w i = svInteger (KBounded False w) i
randomSBVBitVector :: Int -> IO (SBV BitVector)
randomSBVBitVector w = do
bs <- replicateM w randomIO
let x = sum [ bit i | (i, b) <- zip [0..] bs, b ]
return (literal (bv w x))
forallBV_ :: Int -> Symbolic SWord
forallBV_ w = svMkSymVar (Just ALL) (KBounded False w) Nothing
mkSymBitVector :: Maybe Quantifier -> Maybe String -> Int -> Symbolic (SBV BitVector)
mkSymBitVector mbQ mbNm w =
mkSymSBVWithRandom (randomSBVBitVector w) mbQ (KBounded False w) mbNm
existsBV_ :: Int -> Symbolic SWord
existsBV_ w = svMkSymVar (Just EX) (KBounded False w) Nothing
forallBV :: String -> Int -> Symbolic (SBV BitVector)
forallBV nm w = mkSymBitVector (Just ALL) (Just nm) w
forallSBool_ :: Symbolic SBool
forallSBool_ = svMkSymVar (Just ALL) KBool Nothing
forallBV_ :: Int -> Symbolic (SBV BitVector)
forallBV_ w = mkSymBitVector (Just ALL) Nothing w
existsBV :: String -> Int -> Symbolic (SBV BitVector)
existsBV nm w = mkSymBitVector (Just EX) (Just nm) w
existsBV_ :: Int -> Symbolic (SBV BitVector)
existsBV_ w = mkSymBitVector (Just EX) Nothing w
existsSBool_ :: Symbolic SBool
existsSBool_ = svMkSymVar (Just EX) KBool Nothing

View File

@ -12,10 +12,10 @@
module Cryptol.Symbolic.Prims where
import Control.Applicative
import Data.Bits
import Data.List (genericDrop, genericReplicate, genericSplitAt, genericTake, sortBy, transpose)
import Data.Ord (comparing)
import Cryptol.Eval.Value (BitWord(..))
import Cryptol.Prims.Eval (binary, unary, tlamN)
import Cryptol.Prims.Syntax (ECon(..))
import Cryptol.Symbolic.BitVector
@ -25,7 +25,9 @@ import Cryptol.TypeCheck.Solver.InfNat(Nat'(..), nMul)
import Cryptol.Utils.Panic
import qualified Data.SBV as SBV
import Data.SBV (SBool)
import qualified Data.SBV.Internals as SBV (SBV(..))
import qualified Data.SBV.Dynamic as SBV
--import Data.SBV (SBool)
traverseSnd :: Functor f => (a -> f b) -> (t, a) -> f (t, b)
traverseSnd f (x, y) = (,) x <$> f y
@ -36,25 +38,25 @@ traverseSnd f (x, y) = (,) x <$> f y
evalECon :: ECon -> Value
evalECon econ =
case econ of
ECTrue -> VBit SBV.true
ECFalse -> VBit SBV.false
ECTrue -> VBit SBV.svTrue
ECFalse -> VBit SBV.svFalse
ECDemote -> ecDemoteV -- Converts a numeric type into its corresponding value.
-- { val, bits } (fin val, fin bits, bits >= width val) => [bits]
ECPlus -> binary (arithBinary (+)) -- {a} (Arith a) => a -> a -> a
ECMinus -> binary (arithBinary (-)) -- {a} (Arith a) => a -> a -> a
ECMul -> binary (arithBinary (*)) -- {a} (Arith a) => a -> a -> a
ECDiv -> binary (arithBinary (SBV.sQuot)) -- {a} (Arith a) => a -> a -> a
ECMod -> binary (arithBinary (SBV.sRem)) -- {a} (Arith a) => a -> a -> a
ECPlus -> binary (arithBinary SBV.svPlus) -- {a} (Arith a) => a -> a -> a
ECMinus -> binary (arithBinary SBV.svMinus) -- {a} (Arith a) => a -> a -> a
ECMul -> binary (arithBinary SBV.svTimes) -- {a} (Arith a) => a -> a -> a
ECDiv -> binary (arithBinary SBV.svQuot) -- {a} (Arith a) => a -> a -> a
ECMod -> binary (arithBinary SBV.svRem) -- {a} (Arith a) => a -> a -> a
ECExp -> binary (arithBinary sExp) -- {a} (Arith a) => a -> a -> a
ECLg2 -> unary (arithUnary sLg2) -- {a} (Arith a) => a -> a
ECNeg -> unary (arithUnary negate)
ECNeg -> unary (arithUnary SBV.svUNeg)
ECLt -> binary (cmpBinary cmpLt cmpLt SBV.false)
ECGt -> binary (cmpBinary cmpGt cmpGt SBV.false)
ECLtEq -> binary (cmpBinary cmpLtEq cmpLtEq SBV.true)
ECGtEq -> binary (cmpBinary cmpGtEq cmpGtEq SBV.true)
ECEq -> binary (cmpBinary cmpEq cmpEq SBV.true)
ECNotEq -> binary (cmpBinary cmpNotEq cmpNotEq SBV.false)
ECLt -> binary (cmpBinary cmpLt cmpLt SBV.svFalse)
ECGt -> binary (cmpBinary cmpGt cmpGt SBV.svFalse)
ECLtEq -> binary (cmpBinary cmpLtEq cmpLtEq SBV.svTrue)
ECGtEq -> binary (cmpBinary cmpGtEq cmpGtEq SBV.svTrue)
ECEq -> binary (cmpBinary cmpEq cmpEq SBV.svTrue)
ECNotEq -> binary (cmpBinary cmpNotEq cmpNotEq SBV.svFalse)
-- FIXME: the next 4 "primitives" should be defined in the Cryptol prelude.
ECFunEq -> -- {a b} (Cmp b) => (a -> b) -> (a -> b) -> a -> Bit
@ -63,7 +65,7 @@ evalECon econ =
tlam $ \b ->
VFun $ \f ->
VFun $ \g ->
VFun $ \x -> cmpBinary cmpEq cmpEq SBV.true b (fromVFun f x) (fromVFun g x)
VFun $ \x -> cmpBinary cmpEq cmpEq SBV.svTrue b (fromVFun f x) (fromVFun g x)
ECFunNotEq -> -- {a b} (Cmp b) => (a -> b) -> (a -> b) -> a -> Bit
-- (f !== g) x = (f x != g x)
@ -71,24 +73,24 @@ evalECon econ =
tlam $ \b ->
VFun $ \f ->
VFun $ \g ->
VFun $ \x -> cmpBinary cmpNotEq cmpNotEq SBV.false b (fromVFun f x) (fromVFun g x)
VFun $ \x -> cmpBinary cmpNotEq cmpNotEq SBV.svFalse b (fromVFun f x) (fromVFun g x)
ECMin -> -- {a} (Cmp a) => a -> a -> a
-- min x y = if x <= y then x else y
binary $ \a x y ->
let c = cmpBinary cmpLtEq cmpLtEq SBV.false a x y
in SBV.ite (fromVBit c) x y
let c = cmpBinary cmpLtEq cmpLtEq SBV.svFalse a x y
in SBV.ite (SBV.SBV (fromVBit c)) x y
ECMax -> -- {a} (Cmp a) => a -> a -> a
-- max x y = if y <= x then x else y
binary $ \a x y ->
let c = cmpBinary cmpLtEq cmpLtEq SBV.false a y x
in SBV.ite (fromVBit c) x y
let c = cmpBinary cmpLtEq cmpLtEq SBV.svFalse a y x
in SBV.ite (SBV.SBV (fromVBit c)) x y
ECAnd -> binary (logicBinary (SBV.&&&) (.&.))
ECOr -> binary (logicBinary (SBV.|||) (.|.))
ECXor -> binary (logicBinary (SBV.<+>) xor)
ECCompl -> unary (logicUnary (SBV.bnot) complement)
ECAnd -> binary (logicBinary SBV.svAnd SBV.svAnd)
ECOr -> binary (logicBinary SBV.svOr SBV.svOr)
ECXor -> binary (logicBinary SBV.svXOr SBV.svXOr)
ECCompl -> unary (logicUnary SBV.svNot SBV.svNot)
ECZero -> VPoly zeroV
ECShiftL -> -- {m,n,a} (fin n) => [m] a -> [n] -> [m] a
@ -98,7 +100,7 @@ evalECon econ =
VFun $ \xs ->
VFun $ \y ->
case xs of
VWord x -> VWord (SBV.sbvShiftLeft x (fromVWord y))
VWord x -> VWord (SBV.svShiftLeft x (fromVWord y))
_ -> selectV shl y
where
shl :: Integer -> Value
@ -115,7 +117,7 @@ evalECon econ =
VFun $ \xs ->
VFun $ \y ->
case xs of
VWord x -> VWord (SBV.sbvShiftRight x (fromVWord y))
VWord x -> VWord (SBV.svShiftRight x (fromVWord y))
_ -> selectV shr y
where
shr :: Integer -> Value
@ -132,7 +134,7 @@ evalECon econ =
VFun $ \xs ->
VFun $ \y ->
case xs of
VWord x -> VWord (SBV.sbvRotateLeft x (fromVWord y))
VWord x -> VWord (SBV.svRotateLeft x (fromVWord y))
_ -> selectV rol y
where
rol :: Integer -> Value
@ -146,7 +148,7 @@ evalECon econ =
VFun $ \xs ->
VFun $ \y ->
case xs of
VWord x -> VWord (SBV.sbvRotateRight x (fromVWord y))
VWord x -> VWord (SBV.svRotateRight x (fromVWord y))
_ -> selectV ror y
where
ror :: Integer -> Value
@ -235,13 +237,13 @@ evalECon econ =
ECInfFrom ->
tlam $ \(finTValue -> bits) ->
lam $ \(fromVWord -> first) ->
toStream [ VWord (first + SBV.literal (bv (fromInteger bits) i)) | i <- [0 ..] ]
toStream [ VWord (SBV.svPlus first (literalSWord (fromInteger bits) i)) | i <- [0 ..] ]
ECInfFromThen -> -- {a} (fin a) => [a] -> [a] -> [inf][a]
tlam $ \_ ->
lam $ \(fromVWord -> first) ->
lam $ \(fromVWord -> next) ->
toStream (map VWord (iterate (+ (next - first)) first))
toStream (map VWord (iterate (SBV.svPlus (SBV.svMinus next first)) first))
-- {at,len} (fin len) => [len][8] -> at
ECError ->
@ -256,10 +258,10 @@ evalECon econ =
VFun $ \v2 ->
let k = max 1 (i + j) - 1
mul _ [] ps = ps
mul as (b:bs) ps = mul (SBV.false : as) bs (ites b (as `addPoly` ps) ps)
mul as (b:bs) ps = mul (SBV.svFalse : as) bs (ites b (as `addPoly` ps) ps)
xs = map fromVBit (fromSeq v1)
ys = map fromVBit (fromSeq v2)
zs = take (fromInteger k) (mul xs ys [] ++ repeat SBV.false)
zs = take (fromInteger k) (mul xs ys [] ++ repeat SBV.svFalse)
in VSeq True (map VBit zs)
ECPDiv -> -- {a,b} (fin a, fin b) => [a] -> [b] -> [a]
@ -269,7 +271,7 @@ evalECon econ =
VFun $ \v2 ->
let xs = map fromVBit (fromSeq v1)
ys = map fromVBit (fromSeq v2)
zs = take (fromInteger i) (fst (mdp (reverse xs) (reverse ys)) ++ repeat SBV.false)
zs = take (fromInteger i) (fst (mdp (reverse xs) (reverse ys)) ++ repeat SBV.svFalse)
in VSeq True (map VBit (reverse zs))
ECPMod -> -- {a,b} (fin a, fin b) => [a] -> [b+1] -> [b]
@ -279,7 +281,7 @@ evalECon econ =
VFun $ \v2 ->
let xs = map fromVBit (fromSeq v1)
ys = map fromVBit (fromSeq v2)
zs = take (fromInteger j) (snd (mdp (reverse xs) (reverse ys)) ++ repeat SBV.false)
zs = take (fromInteger j) (snd (mdp (reverse xs) (reverse ys)) ++ repeat SBV.svFalse)
in VSeq True (map VBit (reverse zs))
ECRandom -> panic "Cryptol.Symbolic.Prims.evalECon"
@ -293,7 +295,7 @@ selectV f v = sel 0 bits
sel :: Integer -> [SBool] -> Value
sel offset [] = f offset
sel offset (b : bs) = SBV.ite b m1 m2
sel offset (b : bs) = iteValue b m1 m2
where m1 = sel (offset + 2 ^ length bs) bs
m2 = sel offset bs
@ -311,9 +313,9 @@ nthV err v n =
case v of
VStream xs -> nth err xs (fromInteger n)
VSeq _ xs -> nth err xs (fromInteger n)
VWord x -> let i = finiteBitSize x - 1 - fromInteger n
VWord x -> let i = SBV.svBitSize x - 1 - fromInteger n
in if i < 0 then err else
VBit (SBV.sbvTestBit x i)
VBit (SBV.svTestBit x i)
_ -> err
mapV :: Bool -> (Value -> Value) -> Value -> Value
@ -325,8 +327,8 @@ mapV isBit f v =
catV :: Value -> Value -> Value
catV xs (VStream ys) = VStream (fromSeq xs ++ ys)
catV (VWord x) ys = VWord (cat x (fromVWord ys))
catV xs (VWord y) = VWord (cat (fromVWord xs) y)
catV (VWord x) ys = VWord (SBV.svJoin x (fromVWord ys))
catV xs (VWord y) = VWord (SBV.svJoin (fromVWord xs) y)
catV (VSeq b xs) (VSeq _ ys) = VSeq b (xs ++ ys)
catV _ _ = panic "Cryptol.Symbolic.Prims.catV" [ "non-concatenable value" ]
@ -336,13 +338,13 @@ dropV n xs =
case xs of
VSeq b xs' -> VSeq b (genericDrop n xs')
VStream xs' -> VStream (genericDrop n xs')
VWord w -> VWord $ extract (finiteBitSize w - 1 - fromInteger n) 0 w
VWord w -> VWord $ SBV.svExtract (SBV.svBitSize w - 1 - fromInteger n) 0 w
_ -> panic "Cryptol.Symbolic.Prims.dropV" [ "non-droppable value" ]
takeV :: Integer -> Value -> Value
takeV n xs =
case xs of
VWord w -> VWord $ extract (finiteBitSize w - 1) (finiteBitSize w - fromInteger n) w
VWord w -> VWord $ SBV.svExtract (SBV.svBitSize w - 1) (SBV.svBitSize w - fromInteger n) w
VSeq b xs' -> VSeq b (genericTake n xs')
VStream xs' -> VSeq b (genericTake n xs')
where b = case xs' of VBit _ : _ -> True
@ -355,7 +357,7 @@ ecDemoteV :: Value
ecDemoteV = tlam $ \valT ->
tlam $ \bitT ->
case (numTValue valT, numTValue bitT) of
(Nat v, Nat bs) -> VWord (SBV.literal (bv (fromInteger bs) v))
(Nat v, Nat bs) -> VWord (literalSWord (fromInteger bs) v)
_ -> evalPanic "Cryptol.Prove.evalECon"
["Unexpected Inf in constant."
, show valT
@ -418,19 +420,19 @@ arithUnary op = loop . toTypeVal
TVFun _ t -> VFun (\x -> loop t (fromVFun v x))
sExp :: SWord -> SWord -> SWord
sExp x y = go (SBV.blastLE y)
where go [] = SBV.literal (bv (finiteBitSize x) 1)
go (b : bs) = SBV.ite b (x * s) s
sExp x y = go (unpackWord y)
where go [] = literalSWord (SBV.svBitSize x) 1
go (b : bs) = SBV.svIte b (SBV.svTimes x s) s
where a = go bs
s = a * a
s = SBV.svTimes a a
-- | Ceiling (log_2 x)
sLg2 :: SWord -> SWord
sLg2 x = go 0
where
lit n = SBV.literal (bv (finiteBitSize x) n)
go i | i < finiteBitSize x = SBV.ite ((SBV..<=) x (lit (2^i))) (lit (toInteger i)) (go (i + 1))
| otherwise = lit (toInteger i)
lit n = literalSWord (SBV.svBitSize x) n
go i | i < SBV.svBitSize x = SBV.svIte (SBV.svLessEq x (lit (2^i))) (lit (toInteger i)) (go (i + 1))
| otherwise = lit (toInteger i)
-- Cmp -------------------------------------------------------------------------
@ -461,19 +463,19 @@ cmpValue fb fw = cmp
cmpValues (x1 : xs1) (x2 : xs2) k = cmp x1 x2 (cmpValues xs1 xs2 k)
cmpValues _ _ k = k
cmpEq :: SBV.EqSymbolic a => a -> a -> SBool -> SBool
cmpEq x y k = (SBV.&&&) ((SBV..==) x y) k
cmpEq :: SWord -> SWord -> SBool -> SBool
cmpEq x y k = SBV.svAnd (SBV.svEqual x y) k
cmpNotEq :: SBV.EqSymbolic a => a -> a -> SBool -> SBool
cmpNotEq x y k = (SBV.|||) ((SBV../=) x y) k
cmpNotEq :: SWord -> SWord -> SBool -> SBool
cmpNotEq x y k = SBV.svOr (SBV.svNotEqual x y) k
cmpLt, cmpGt :: SBV.OrdSymbolic a => a -> a -> SBool -> SBool
cmpLt x y k = (SBV.|||) ((SBV..<) x y) (cmpEq x y k)
cmpGt x y k = (SBV.|||) ((SBV..>) x y) (cmpEq x y k)
cmpLt, cmpGt :: SWord -> SWord -> SBool -> SBool
cmpLt x y k = SBV.svOr (SBV.svLessThan x y) (cmpEq x y k)
cmpGt x y k = SBV.svOr (SBV.svGreaterThan x y) (cmpEq x y k)
cmpLtEq, cmpGtEq :: SBV.OrdSymbolic a => a -> a -> SBool -> SBool
cmpLtEq x y k = (SBV.&&&) ((SBV..<=) x y) (cmpNotEq x y k)
cmpGtEq x y k = (SBV.&&&) ((SBV..>=) x y) (cmpNotEq x y k)
cmpLtEq, cmpGtEq :: SWord -> SWord -> SBool -> SBool
cmpLtEq x y k = SBV.svAnd (SBV.svLessEq x y) (cmpNotEq x y k)
cmpGtEq x y k = SBV.svAnd (SBV.svGreaterEq x y) (cmpNotEq x y k)
cmpBinary :: (SBool -> SBool -> SBool -> SBool)
-> (SWord -> SWord -> SBool -> SBool)
@ -500,8 +502,8 @@ zeroV = go . toTypeVal
where
go ty =
case ty of
TVBit -> VBit SBV.false
TVSeq n TVBit -> VWord (SBV.literal (bv n 0))
TVBit -> VBit SBV.svFalse
TVSeq n TVBit -> VWord (literalSWord n 0)
TVSeq n t -> VSeq False (replicate n (go t))
TVStream t -> VStream (repeat (go t))
TVTuple ts -> VTuple [ go t | t <- ts ]
@ -575,7 +577,7 @@ fromThenV =
case (first, next, len, bits) of
(Nat first', Nat next', Nat len', Nat bits') ->
let nums = enumFromThen first' next'
lit i = VWord (SBV.literal (bv (fromInteger bits') i))
lit i = VWord (literalSWord (fromInteger bits') i)
in VSeq False (genericTake len' (map lit nums))
_ -> evalPanic "fromThenV" ["invalid arguments"]
@ -590,7 +592,7 @@ fromToV =
(Nat first', Nat lst', Nat bits') ->
let nums = enumFromThenTo first' (first' + 1) lst'
len = 1 + (lst' - first')
lit i = VWord (SBV.literal (bv (fromInteger bits') i))
lit i = VWord (literalSWord (fromInteger bits') i)
in VSeq False (genericTake len (map lit nums))
_ -> evalPanic "fromThenV" ["invalid arguments"]
@ -607,7 +609,7 @@ fromThenToV =
(Nat first', Nat next', Nat lst', Nat len', Nat bits') ->
let nums = enumFromThenTo first' next' lst'
lit i = VWord (SBV.literal (bv (fromInteger bits') i))
lit i = VWord (literalSWord (fromInteger bits') i)
in VSeq False (genericTake len' (map lit nums))
_ -> evalPanic "fromThenV" ["invalid arguments"]
@ -621,25 +623,25 @@ fromThenToV =
addPoly :: [SBool] -> [SBool] -> [SBool]
addPoly xs [] = xs
addPoly [] ys = ys
addPoly (x:xs) (y:ys) = x SBV.<+> y : addPoly xs ys
addPoly (x:xs) (y:ys) = SBV.svXOr x y : addPoly xs ys
ites :: SBool -> [SBool] -> [SBool] -> [SBool]
ites s xs ys
| Just t <- SBV.unliteral s
| Just t <- SBV.svAsBool s
= if t then xs else ys
| True
= go xs ys
where go [] [] = []
go [] (b:bs) = SBV.ite s SBV.false b : go [] bs
go (a:as) [] = SBV.ite s a SBV.false : go as []
go (a:as) (b:bs) = SBV.ite s a b : go as bs
go [] (b:bs) = SBV.svIte s SBV.svFalse b : go [] bs
go (a:as) [] = SBV.svIte s a SBV.svFalse : go as []
go (a:as) (b:bs) = SBV.svIte s a b : go as bs
-- conservative over-approximation of the degree
degree :: [SBool] -> Int
degree xs = walk (length xs - 1) $ reverse xs
where walk n [] = n
walk n (b:bs)
| Just t <- SBV.unliteral b
| Just t <- SBV.svAsBool b
= if t then n else walk (n-1) bs
| True
= n -- over-estimate
@ -653,13 +655,13 @@ mdp xs ys = go (length ys - 1) (reverse ys)
| True = let (rqs, rrs) = go (n-1) bs
in (ites b (reverse qs) rqs, ites b rs rrs)
where degQuot = degTop - n
ys' = replicate degQuot SBV.false ++ ys
ys' = replicate degQuot SBV.svFalse ++ ys
(qs, rs) = divx (degQuot+1) degTop xs ys'
-- return the element at index i; if not enough elements, return false
-- N.B. equivalent to '(xs ++ repeat false) !! i', but more efficient
idx :: [SBool] -> Int -> SBool
idx [] _ = SBV.false
idx [] _ = SBV.svFalse
idx (x:_) 0 = x
idx (_:xs) i = idx xs (i-1)

View File

@ -18,19 +18,19 @@ module Cryptol.Symbolic.Value
, fromVBit, fromVFun, fromVPoly, fromVTuple, fromVRecord, lookupRecord
, fromSeq, fromVWord
, evalPanic
, iteValue, mergeValue
)
where
import Data.Bits (finiteBitSize)
import Cryptol.Eval.Value (TValue, numTValue, toNumTValue, finTValue, isTBit, isTFun, isTSeq, isTTuple, isTRec, tvSeq,
GenValue(..), BitWord(..), lam, tlam, toStream, toFinSeq, toSeq, fromSeq,
fromVBit, fromVWord, fromVFun, fromVPoly, fromVTuple, fromVRecord, lookupRecord)
import Cryptol.Symbolic.BitVector
import Cryptol.Utils.Panic (panic)
import Data.SBV (SBool, fromBitsBE, Mergeable(..), HasKind(..), EqSymbolic(..))
import Data.SBV.Internals (symbolicMergeWithKind)
import Data.SBV (Mergeable(..))
import Data.SBV.Internals (SBV(..))
import Data.SBV.Dynamic
-- Values ----------------------------------------------------------------------
@ -38,38 +38,46 @@ type Value = GenValue SBool SWord
-- Symbolic Conditionals -------------------------------------------------------
iteValue :: SBool -> Value -> Value -> Value
iteValue c x y =
case svAsBool c of
Just True -> x
Just False -> y
Nothing -> mergeValue True c x y
mergeValue :: Bool -> SBool -> Value -> Value -> Value
mergeValue f c v1 v2 =
case (v1, v2) of
(VRecord fs1, VRecord fs2) -> VRecord $ zipWith mergeField fs1 fs2
(VTuple vs1 , VTuple vs2 ) -> VTuple $ zipWith (mergeValue f c) vs1 vs2
(VBit b1 , VBit b2 ) -> VBit $ mergeBit b1 b2
(VWord w1 , VWord w2 ) -> VWord $ mergeWord w1 w2
(VSeq b1 vs1, VSeq _ vs2 ) -> VSeq b1 $ zipWith (mergeValue f c) vs1 vs2
(VStream vs1, VStream vs2) -> VStream $ mergeStream vs1 vs2
(VFun f1 , VFun f2 ) -> VFun $ \x -> mergeValue f c (f1 x) (f2 x)
(VPoly f1 , VPoly f2 ) -> VPoly $ \x -> mergeValue f c (f1 x) (f2 x)
(VWord w1 , _ ) -> VWord $ mergeWord w1 (fromVWord v2)
(_ , VWord w2 ) -> VWord $ mergeWord (fromVWord v1) w2
(_ , _ ) -> panic "Cryptol.Symbolic.Value"
[ "mergeValue: incompatible values" ]
where
mergeBit b1 b2 = svSymbolicMerge KBool f c b1 b2
mergeWord w1 w2 = svSymbolicMerge (svKind w1) f c w1 w2
mergeField (n1, x1) (n2, x2)
| n1 == n2 = (n1, mergeValue f c x1 x2)
| otherwise = panic "Cryptol.Symbolic.Value"
[ "mergeValue.mergeField: incompatible values" ]
mergeStream xs ys =
mergeValue f c (head xs) (head ys) : mergeStream (tail xs) (tail ys)
instance Mergeable Value where
symbolicMerge f c v1 v2 =
case (v1, v2) of
(VRecord fs1, VRecord fs2) -> VRecord $ zipWith mergeField fs1 fs2
(VTuple vs1 , VTuple vs2 ) -> VTuple $ zipWith (symbolicMerge f c) vs1 vs2
(VBit b1 , VBit b2 ) -> VBit $ symbolicMerge f c b1 b2
(VWord w1 , VWord w2 ) -> VWord $ mergeWord w1 w2
(VSeq b1 vs1, VSeq _ vs2 ) -> VSeq b1 $ symbolicMerge f c vs1 vs2
(VStream vs1, VStream vs2) -> VStream $ mergeStream vs1 vs2
(VFun f1 , VFun f2 ) -> VFun $ symbolicMerge f c f1 f2
(VPoly f1 , VPoly f2 ) -> VPoly $ symbolicMerge f c f1 f2
(VWord w1 , _ ) -> VWord $ mergeWord w1 (fromVWord v2)
(_ , VWord w2 ) -> VWord $ mergeWord (fromVWord v1) w2
(_ , _ ) -> panic "Cryptol.Symbolic.Value"
[ "symbolicMerge: incompatible values" ]
where
mergeWord w1 w2 = symbolicMergeWithKind (kindOf w1) f c w1 w2
mergeField (n1, x1) (n2, x2)
| n1 == n2 = (n1, symbolicMerge f c x1 x2)
| otherwise = panic "Cryptol.Symbolic.Value"
[ "symbolicMerge.mergeField: incompatible values" ]
mergeStream xs ys =
symbolicMerge f c (head xs) (head ys) : mergeStream (tail xs) (tail ys)
symbolicMerge f (SBV c) = mergeValue f c
-- Big-endian Words ------------------------------------------------------------
instance BitWord SBool SWord where
packWord bs = Data.SBV.fromBitsBE bs
unpackWord w = [ sbvTestBit' w i | i <- reverse [0 .. finiteBitSize w - 1] ]
sbvTestBit' :: SWord -> Int -> SBool
sbvTestBit' w i = extract i i w .== literalSWord 1 1
packWord bs = fromBitsLE (reverse bs)
unpackWord x = [ svTestBit x i | i <- reverse [0 .. svBitSize x - 1] ]
-- Errors ----------------------------------------------------------------------