start working on execMacawStmtExtension for new memory model

This commit is contained in:
Daniel Wagner 2020-03-09 23:46:58 -04:00
parent 5506e05486
commit f4daaa7e81

View File

@ -10,23 +10,41 @@
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Macaw.Symbolic.MemOps2 where
import Control.Applicative
import Control.Lens ((%~), (&), (^.))
import Control.Monad.State
import Data.List
import Data.Sequence (Seq)
import qualified Data.Sequence as Seq
import qualified Data.Vector as V
import GHC.TypeNats (type (<=))
import GHC.TypeNats (type (<=), KnownNat)
import Numeric
import Data.Macaw.CFG.AssignRhs (MemRepr(..))
import Data.Macaw.CFG.AssignRhs (ArchAddrWidth, MemRepr(..))
import Data.Macaw.Memory (AddrWidthRepr(..), Endianness, addrWidthNatRepr)
import Data.Macaw.Symbolic.Backend (EvalStmtFunc, MacawArchEvalFn(..))
import Data.Macaw.Symbolic.CrucGen (MacawArchStmtExtension, MacawStmtExtension(..), MacawExt)
import Data.Macaw.Symbolic.MemOps (LookupFunctionHandle(..), MacawSimulatorState(..))
import Data.Macaw.Symbolic.PersistentState (ToCrucibleType)
import Lang.Crucible.Backend (IsSymInterface)
import Lang.Crucible.LLVM.MemModel (LLVMPtr, pattern LLVMPointer, llvmPointer_bv)
import Data.Parameterized.Context (pattern (:>), pattern Empty)
import Lang.Crucible.Backend (IsSymInterface, LabeledPred(..), assert)
import Lang.Crucible.CFG.Common (GlobalVar)
import Lang.Crucible.LLVM.MemModel (LLVMPointerType, LLVMPtr, pattern LLVMPointer, llvmPointer_bv)
import Lang.Crucible.Simulator.ExecutionTree (CrucibleState, actFrame, gpGlobals, stateSymInterface, stateTree)
import Lang.Crucible.Simulator.GlobalState (insertGlobal, lookupGlobal)
import Lang.Crucible.Simulator.Intrinsics (IntrinsicClass(..))
import Lang.Crucible.Simulator.RegMap (RegEntry(..))
import Lang.Crucible.Simulator.RegValue (RegValue)
import Lang.Crucible.Types (BoolType)
import Lang.Crucible.Simulator.SimError (SimErrorReason(..))
import Lang.Crucible.Types ((::>), BoolType, BVType, EmptyCtx, IntrinsicType, NatType, TypeRepr(BVRepr))
import What4.Concrete (ConcreteVal(..))
import What4.Expr.Builder (ExprBuilder)
import What4.Interface -- (NatRepr, knownRepr, BaseTypeRepr(..), SolverSymbol, userSymbol, freshConstant, natLit)
data MemOpCondition sym
@ -39,15 +57,223 @@ data MemOpDirection = Read | Write deriving (Bounded, Enum, Eq, Ord, Read, Show)
data MemOp sym ptrW where
MemOp ::
{ moAddr :: LLVMPtr sym ptrW
, moDir :: MemOpDirection
, moCond :: MemOpCondition sym
, moSize :: NatRepr w
, moVal :: LLVMPtr sym w
, moEnd :: Endianness
} -> MemOp sym ptrW
-- The address of the operation
LLVMPtr sym ptrW ->
MemOpDirection ->
MemOpCondition sym ->
-- The size of the operation in bits
NatRepr w ->
-- The value read or written during the operation
LLVMPtr sym w ->
Endianness ->
MemOp sym ptrW
MergeOps ::
Pred sym ->
MemImpl2 sym ptrW ->
MemImpl2 sym ptrW ->
MemOp sym ptrW
type MemImpl2 sym ptrW = [MemOp sym ptrW]
instance Eq (MemOpCondition (ExprBuilder t st fs)) where
Unconditional == Unconditional = True
Conditional p == Conditional p' = p == p'
_ == _ = False
instance Eq (MemOp (ExprBuilder t st fs) ptrW) where
MemOp (LLVMPointer addrR addrO) dir cond repr (LLVMPointer valR valO) end
== MemOp (LLVMPointer addrR' addrO') dir' cond' repr' (LLVMPointer valR' valO') end' = case testEquality repr repr' of
Nothing -> False
Just Refl -> addrR == addrR' && addrO == addrO' && dir == dir' && cond == cond' && valR == valR' && valO == valO' && end == end'
MergeOps p opsT opsF == MergeOps p' opsT' opsF' = p == p' && opsT == opsT' && opsF == opsF'
_ == _ = False
-- TODO: how about "trace memory model"-based naming instead of "2"-based naming?
type MemImpl2 sym ptrW = Seq (MemOp sym ptrW)
type Mem2 arch = IntrinsicType "Abstract_memory" (EmptyCtx ::> BVType (ArchAddrWidth arch))
memRepr2 :: (KnownNat (ArchAddrWidth arch), 1 <= ArchAddrWidth arch) => TypeRepr (Mem2 arch)
memRepr2 = knownRepr
instance IntrinsicClass (ExprBuilder t st fs) "Abstract_memory" where
-- TODO: cover other cases with a TypeError
type Intrinsic (ExprBuilder t st fs) "Abstract_memory" (EmptyCtx ::> BVType ptrW) = MemImpl2 (ExprBuilder t st fs) ptrW
muxIntrinsic _ _ _ (Empty :> BVRepr _) p l r = pure $ case Seq.spanl (uncurry (==)) (Seq.zip l r) of
(_, Seq.Empty) -> l
(eqs, Seq.unzip -> (l', r')) -> (fst <$> eqs) Seq.:|> MergeOps p l' r'
type MacawEvalStmtFunc sym arch = EvalStmtFunc (MacawStmtExtension arch) (MacawSimulatorState sym) sym (MacawExt arch)
type GlobalMap2 sym arch = sym -> MemImpl2 sym (ArchAddrWidth arch) -> RegValue sym NatType -> RegValue sym (BVType (ArchAddrWidth arch)) -> IO (Maybe (LLVMPtr sym (ArchAddrWidth arch)))
type MacawArchEvalFn2 sym arch = GlobalVar (Mem2 arch) -> GlobalMap2 sym arch -> EvalStmtFunc (MacawArchStmtExtension arch) (MacawSimulatorState sym) sym (MacawExt arch)
execMacawStmtExtension ::
forall sym arch t st fs. (IsSymInterface sym, KnownNat (ArchAddrWidth arch), sym ~ ExprBuilder t st fs) =>
MacawArchEvalFn2 sym arch ->
GlobalVar (Mem2 arch) ->
GlobalMap2 sym arch ->
MacawEvalStmtFunc sym arch
execMacawStmtExtension archStmtFn mvar globs stmt
= case stmt of
MacawReadMem addrWidth memRepr addr
-> liftToCrucibleState mvar $ \sym ->
doReadMem sym addrWidth (regValue addr) memRepr
MacawCondReadMem addrWidth memRepr cond addr def
-> liftToCrucibleState mvar $ \sym ->
doCondReadMem sym (regValue cond) (regValue def) addrWidth (regValue addr) memRepr
MacawWriteMem addrWidth memRepr addr val
-> liftToCrucibleState mvar $ \sym ->
doWriteMem sym addrWidth (regValue addr) (regValue val) memRepr
MacawCondWriteMem addrWidth memRepr cond addr def
-> liftToCrucibleState mvar $ \sym ->
doCondWriteMem sym (regValue cond) addrWidth (regValue addr) (regValue def) memRepr
MacawGlobalPtr w addr -> undefined
MacawFreshSymbolic t -> undefined
MacawLookupFunctionHandle typeReps registers -> undefined
MacawArchStmtExtension archStmt -> archStmtFn mvar globs archStmt
MacawArchStateUpdate{} -> \cst -> pure ((), cst)
MacawInstructionStart{} -> \cst -> pure ((), cst)
PtrEq w x y -> ptrOp w x y $ \sym reg off reg' off' -> do
regEq <- natEq sym reg reg'
offEq <- bvEq sym off off'
andPred sym regEq offEq
PtrLeq w x y -> ptrOp w x y $ \sym reg off reg' off' -> do
whoKnows <- ioFreshConstant sym "PtrLeq_across_allocations" knownRepr
regEq <- natEq sym reg reg'
offLeq <- bvUle sym off off'
itePred sym regEq offLeq whoKnows
PtrLt w x y -> ptrOp w x y $ \sym reg off reg' off' -> do
whoKnows <- ioFreshConstant sym "PtrLt_across_allocations" knownRepr
regEq <- natEq sym reg reg'
offLt <- bvUlt sym off off'
itePred sym regEq offLt whoKnows
PtrMux w (RegEntry _ p) x y -> ptrOp w x y $ \sym reg off reg' off' -> do
reg'' <- natIte sym p reg reg'
off'' <- bvIte sym p off off'
pure (LLVMPointer reg'' off'')
PtrAdd w x y -> ptrOp w x y $ \sym reg off reg' off' -> do
regZero <- isZero sym reg
regZero' <- isZero sym reg'
someZero <- orPred sym regZero regZero'
assert sym someZero $ AssertFailureSimError
"PtrAdd: expected ptr+constant, saw ptr+ptr"
"When doing pointer addition, we expect at least one of the two arguments to the addition to have a region of 0 (i.e. not be the result of allocating memory). Both arguments had non-0 regions."
reg'' <- cases sym
[ (pure regZero, pure reg')
, (pure regZero', pure reg)
]
(ioFreshConstant sym "PtrAdd_both_ptrs_region" knownRepr)
off'' <- cases sym
[ (pure someZero, bvAdd sym off off')
]
(ioFreshConstant sym "PtrAdd_both_ptrs_offset" knownRepr)
pure (LLVMPointer reg'' off'')
PtrSub w x y -> ptrOp w x y $ \sym reg off reg' off' -> do
regZero' <- isZero sym reg'
regEq <- natEq sym reg reg'
compatSub <- orPred sym regZero' regEq
assert sym compatSub $ AssertFailureSimError
"PtrSub: strange mix of allocation regions"
"When doing pointer subtraction, we expect that either the two pointers are from the same allocation region or the negated one is actually a constant. Other mixings of allocation regions have arbitrary behavior."
reg'' <- cases sym
[ (pure regZero', pure reg)
, (pure regEq, natLit sym 0)
]
(ioFreshConstant sym "PtrSub_region_mismatch" knownRepr)
off'' <- cases sym
[ (pure compatSub, bvSub sym off off')
]
(ioFreshConstant sym "PtrSub_region_mismatch" knownRepr)
pure (LLVMPointer reg'' off'')
PtrAnd w x y -> ptrOp w x y $ \sym reg off reg' off' -> do
-- TODO: assertions, like in PtrAdd and PtrSub
reg'' <- cases sym
[ (isZero sym reg, pure reg')
, (isZero sym reg', pure reg)
, (natEq sym reg reg', pure reg)
]
(ioFreshConstant sym "PtrAnd_across_allocations_region" knownRepr)
off'' <- cases sym
undefined
undefined
pure (LLVMPointer reg'' off'')
liftToCrucibleState ::
GlobalVar mem ->
(sym -> StateT (RegValue sym mem) IO a) ->
CrucibleState p sym ext rtp blocks r ctx ->
IO (a, CrucibleState p sym ext rtp blocks r ctx)
liftToCrucibleState mvar f cst = do
mem <- getGlobalVar cst mvar
(a, mem') <- runStateT (f (cst ^. stateSymInterface)) mem
pure (a, setGlobalVar cst mvar mem')
readOnlyWithSym ::
(sym -> IO a) ->
CrucibleState p sym ext rtp blocks r ctx ->
IO (a, CrucibleState p sym ext rtp blocks r ctx)
readOnlyWithSym f cst = flip (,) cst <$> f (cst ^. stateSymInterface)
getGlobalVar :: CrucibleState s sym ext rtp blocks r ctx -> GlobalVar mem -> IO (RegValue sym mem)
getGlobalVar cst gv = case lookupGlobal gv (cst ^. stateTree . actFrame . gpGlobals) of
Just val -> return val
Nothing -> fail ("Global variable not initialized: " ++ show gv)
setGlobalVar :: CrucibleState s sym ext rtp blocks r ctx -> GlobalVar mem -> RegValue sym mem -> CrucibleState s sym ext rtp blocks r ctx
setGlobalVar cst gv val = cst & stateTree . actFrame . gpGlobals %~ insertGlobal gv val
ptrOp ::
AddrWidthRepr w ->
RegEntry sym (LLVMPointerType w) ->
RegEntry sym (LLVMPointerType w) ->
(1 <= w => sym -> SymNat sym -> SymBV sym w -> SymNat sym -> SymBV sym w -> IO a) ->
CrucibleState p sym ext rtp blocks r ctx ->
IO (a, CrucibleState p sym ext rtp blocks r ctx)
ptrOp w (RegEntry _ (LLVMPointer region offset)) (RegEntry _ (LLVMPointer region' offset')) f =
addrWidthsArePositive w $ readOnlyWithSym $ \sym -> f sym region offset region' offset'
cases ::
IsExprBuilder sym =>
sym ->
[(IO (Pred sym), IO (SymExpr sym tp))] ->
IO (SymExpr sym tp) ->
IO (SymExpr sym tp)
cases sym branches def = go branches where
go [] = def
go ((iop, iov):bs) = do
p <- iop
vT <- iov
vF <- go bs
baseTypeIte sym p vT vF
isZero :: IsExprBuilder sym => sym -> SymNat sym -> IO (Pred sym)
isZero sym reg = do
zero <- natLit sym 0
natEq sym reg zero
andIOPred :: IsExprBuilder sym => sym -> IO (Pred sym) -> IO (Pred sym) -> IO (Pred sym)
andIOPred sym p1_ p2_ = do
p1 <- p1_
p2 <- p2_
andPred sym p1 p2
orIOPred :: IsExprBuilder sym => sym -> IO (Pred sym) -> IO (Pred sym) -> IO (Pred sym)
orIOPred sym p1_ p2_ = do
p1 <- p1_
p2 <- p2_
orPred sym p1 p2
doReadMem ::
IsSymInterface sym =>
@ -106,8 +332,9 @@ freshRegValue sym (LLVMPointer reg off) = go 0 where
go :: Integer -> MemRepr ty' -> IO (RegValue sym (ToCrucibleType ty'))
go n (BVMemRepr byteWidth _endianness) = do
let bitWidth = natMultiply (knownNat @8) byteWidth
symbolContent <- ioSolverSymbol . intercalate "_" $ describe byteWidth n
content <- multiplicationIsMonotonic @8 bitWidth $ freshConstant sym symbolContent (BaseBVRepr bitWidth)
content <- multiplicationIsMonotonic @8 bitWidth $ ioFreshConstant sym
(intercalate "_" $ describe byteWidth n)
(BaseBVRepr bitWidth)
llvmPointer_bv sym content
go _n (FloatMemRepr _infoRepr _endianness) = fail "creating fresh float values not supported in freshRegValue"
@ -134,14 +361,8 @@ doMemOpInternal :: forall sym ptrW ty.
doMemOpInternal sym dir cond ptrWidth = go where
go :: LLVMPtr sym ptrW -> RegValue sym (ToCrucibleType ty') -> MemRepr ty' -> StateT (MemImpl2 sym ptrW) IO ()
go ptr@(LLVMPointer reg off) regVal = \case
BVMemRepr byteWidth endianness -> logOp MemOp
{ moAddr = ptr
, moDir = dir
, moCond = cond
, moSize = natMultiply (knownNat @8) byteWidth
, moVal = regVal
, moEnd = endianness
}
BVMemRepr byteWidth endianness -> logOp $ MemOp ptr dir cond bitWidth regVal endianness
where bitWidth = natMultiply (knownNat @8) byteWidth
FloatMemRepr _infoRepr _endianness -> fail "reading floats not supported in doMemOpInternal"
PackedVecMemRepr _countRepr recRepr -> addrWidthsArePositive ptrWidth $ do
elemSize <- liftIO $ bvLit sym ptrWidthNatRepr (memReprByteSize recRepr)
@ -174,7 +395,7 @@ iteDeep sym cond t f = \case
iteDeep sym cond (t V.! i) (f V.! i) recRepr
logOp :: (MonadState (MemImpl2 sym ptrW) m) => MemOp sym ptrW -> m ()
logOp op = modify (op:)
logOp op = modify (Seq.:|> op)
addrWidthsArePositive :: AddrWidthRepr w -> (1 <= w => a) -> a
addrWidthsArePositive Addr32 a = a
@ -192,3 +413,8 @@ memReprByteSize (PackedVecMemRepr countRepr recRepr) = intValue countRepr * memR
ioSolverSymbol :: String -> IO SolverSymbol
ioSolverSymbol = either (fail . show) pure . userSymbol
ioFreshConstant :: IsSymExprBuilder sym => sym -> String -> BaseTypeRepr tp -> IO (SymExpr sym tp)
ioFreshConstant sym nm ty = do
symbol <- ioSolverSymbol nm
freshConstant sym symbol ty