Use the post-withSetup DepState

It's wrong to use initialDepState when there is a setup action, as the
action could end with a DepState which is not the same as the initial
one.  Here's an example of it going wrong:

    > :{
    resultsSet defaultWay defaultMemType $ do
      v <- newMVar ()
      fork (takeMVar v)
      readMVar v
    :}
    fromList [Left Deadlock,Right ()]

    > :{
    resultsSet defaultWay defaultMemType $
      withSetup (newMVar ()) $ \v -> do
        fork (takeMVar v)
        readMVar v
    :}
    fromList [Right ()]

This PR pushes responsibility for the DepState into the Context, and
the DepState is passed to all schedulers.  That means it's been
promoted from an internal type to a user-facing one, so I gave it the
more generic name "ConcurrencyState".  Furthermore, the snapshotted
DepState is passed to all the DPOR functions, and the trace
simplification functions.

initialDepState is now only used:

- in Conc.Internal to initialise a new context
- in SCT.Internal when there is no snapshot
This commit is contained in:
Michael Walker 2019-02-12 19:16:18 +00:00
parent 5d5a6ef2ff
commit cb118e4f41
12 changed files with 330 additions and 234 deletions

View File

@ -22,6 +22,7 @@ import Test.DejaFu (Condition, Predicate,
ProPredicate(..), Result(..), Way,
alwaysTrue, somewhereTrue)
import Test.DejaFu.Conc (randomSched, runConcurrent)
import Test.DejaFu.Internal
import qualified Test.DejaFu.SCT as SCT
import Test.DejaFu.SCT.Internal
import Test.DejaFu.Types
@ -123,10 +124,13 @@ prop_dep_fun safeIO conc = H.property $ do
seed <- H.forAll genInt
fs <- H.forAll $ genList HGen.bool
-- todo: 1 1 is not right if a snapshot is restored
-- todo: this doesn't work with setup actions that (a) fork a
-- thread or (b) make an IORef. this is because it permutes the
-- trace using the initialCState, rather than the post-setup
-- state.
(efa1, tids1, efa2, tids2) <- liftIO $ runNorm
seed
(renumber mem 1 1 . permuteBy safeIO mem (map (\f _ _ -> f) fs))
(renumber mem 1 1 . permuteBy safeIO mem initialCState (map (\f _ _ -> f) fs))
mem
H.footnote (" to: " ++ show tids2)
H.footnote ("rewritten from: " ++ show tids1)
@ -138,7 +142,7 @@ prop_dep_fun safeIO conc = H.property $ do
let tids1 = toTIdTrace trc1
(efa2, _, trc2) <- replay (play memtype conc) (norm tids1)
let tids2 = toTIdTrace trc2
pure (efa1, map fst tids1, efa2, map fst tids2)
pure (efa1, tids1, efa2, tids2)
play memtype c s g = runConcurrent s memtype g c

View File

@ -326,6 +326,11 @@ programTests = toTestList
writeIORef x 1
pure x)
(\x -> takeMVar =<< spawn (readIORef x))
, djfuTS "MVar state is preserved from setup action" (gives [Left Deadlock, Right ()]) $
withSetup (newMVar ()) $ \v -> do
_ <- fork $ takeMVar v
readMVar v
]
, testGroup "withSetupAndTeardown"

View File

@ -245,14 +245,14 @@ memoryProps = toTestList
sctProps :: [TestTree]
sctProps = toTestList
[ testProperty "canInterrupt ==> canInterruptL" $ do
ds <- H.forAll genDepState
ds <- H.forAll genCState
tid <- H.forAll genThreadId
act <- H.forAll (HGen.filter (SCT.canInterrupt ds tid) genThreadAction)
H.assert (SCT.canInterruptL ds tid (D.rewind act))
act <- H.forAll (HGen.filter (D.canInterrupt ds tid) genThreadAction)
H.assert (D.canInterruptL ds tid (D.rewind act))
, testProperty "dependent ==> dependent'" $ do
safeIO <- H.forAll HGen.bool
ds <- H.forAll genDepState
ds <- H.forAll genCState
tid1 <- H.forAll genThreadId
tid2 <- H.forAll genThreadId
ta1 <- H.forAll genThreadAction
@ -261,7 +261,7 @@ sctProps = toTestList
, testProperty "dependent x y == dependent y x" $ do
safeIO <- H.forAll HGen.bool
ds <- H.forAll genDepState
ds <- H.forAll genCState
tid1 <- H.forAll genThreadId
tid2 <- H.forAll genThreadId
ta1 <- H.forAll genThreadAction
@ -269,7 +269,7 @@ sctProps = toTestList
SCT.dependent safeIO ds tid1 ta1 tid2 ta2 H.=== SCT.dependent safeIO ds tid2 ta2 tid1 ta1
, testProperty "dependentActions x y == dependentActions y x" $ do
ds <- H.forAll genDepState
ds <- H.forAll genCState
a1 <- H.forAll genActionType
a2 <- H.forAll genActionType
SCT.dependentActions ds a1 a2 H.=== SCT.dependentActions ds a2 a1
@ -435,8 +435,8 @@ genSynchronisedActionType = HGen.choice
, pure D.SynchronisedOther
]
genDepState :: H.Gen SCT.DepState
genDepState = SCT.DepState
genCState :: H.Gen D.ConcurrencyState
genCState = D.ConcurrencyState
<$> genSmallMap genIORefId genSmallInt
<*> genSmallSet genMVarId
<*> genSmallMap genThreadId genMaskingState

View File

@ -47,9 +47,22 @@ Added
* ``Test.DejaFu.runTestWithSettings`` function.
* A simplified form of the concurrency state:
* ``Test.DejaFu.Types.ConcurrencyState``
* ``Test.DejaFu.Types.isBuffered``
* ``Test.DejaFu.Types.numBuffered``
* ``Test.DejaFu.Types.isFull``
* ``Test.DejaFu.Types.canInterrupt``
* ``Test.DejaFu.Types.canInterruptL``
* ``Test.DejaFu.Types.isMaskedInterruptible``
* ``Test.DejaFu.Types.isMaskedUninterruptible``
Changed
~~~~~~~
* ``Test.DejaFu.Schedule.Scheduler`` has a ``ConcurrencyState``
parameter.
* ``Test.DejaFu.alwaysSameBy`` and ``Test.DejaFu.notAlwaysSameBy``
return a representative trace for each unique condition.

View File

@ -79,6 +79,7 @@ runConcurrency invariants forSnapshot sched memtype g idsrc caps ma = do
, cCaps = caps
, cInvariants = InvariantContext { icActive = invariants, icBlocked = [] }
, cNewInvariants = []
, cCState = initialCState
}
(c, ref) <- runRefCont AStop (Just . Right) (runModelConc ma)
let threads0 = launch' Unmasked initialThread (const c) (cThreads ctx)
@ -125,6 +126,7 @@ data Context n g = Context
, cCaps :: Int
, cInvariants :: InvariantContext n
, cNewInvariants :: [Invariant n ()]
, cCState :: ConcurrencyState
}
-- | Run a collection of threads, until there are no threads left.
@ -171,7 +173,7 @@ runThreads forSnapshot sched memtype ref = schedule (const $ pure ()) Seq.empty
Nothing -> E.throwM ScheduledMissingThread
Nothing -> die Abort restore sofar prior ctx'
where
(choice, g') = scheduleThread sched prior (efromList runnable') (cSchedState ctx)
(choice, g') = scheduleThread sched prior (efromList runnable') (cCState ctx) (cSchedState ctx)
runnable' = [(t, lookahead (_continuation a)) | (t, a) <- sortOn fst $ M.assocs runnable]
runnable = M.filter (not . isBlocked) threadsc
threadsc = addCommitThreads (cWriteBuf ctx) threads
@ -197,7 +199,7 @@ runThreads forSnapshot sched memtype ref = schedule (const $ pure ()) Seq.empty
if forSnapshot
then restore threads' >> actionSnap threads'
else restore threads'
let ctx' = fixContext chosen actOrTrc res ctx
let ctx' = fixContext memtype chosen actOrTrc res ctx
case res of
Succeeded _ -> checkInvariants (cInvariants ctx') >>= \case
Right ic ->
@ -214,22 +216,21 @@ runThreads forSnapshot sched memtype ref = schedule (const $ pure ()) Seq.empty
getPrior a = Just (chosen, a)
-- | Apply the context update from stepping an action.
fixContext :: ThreadId -> ThreadAction -> What n g -> Context n g -> Context n g
fixContext chosen act (Succeeded ctx@Context{..}) _ =
ctx { cThreads = delCommitThreads $
if (interruptible <$> M.lookup chosen cThreads) /= Just False
then unblockWaitingOn chosen cThreads
else cThreads
, cInvariants = unblockInvariants act cInvariants
}
fixContext _ act (Failed _) ctx@Context{..} =
ctx { cThreads = delCommitThreads cThreads
, cInvariants = unblockInvariants act cInvariants
}
fixContext _ act (Snap ctx@Context{..}) _ =
ctx { cThreads = delCommitThreads cThreads
, cInvariants = unblockInvariants act cInvariants
fixContext :: MemType -> ThreadId -> ThreadAction -> What n g -> Context n g -> Context n g
fixContext memtype tid act what ctx0 = fixContextCommon $ case what of
Succeeded ctx@Context{..} -> ctx
{ cThreads =
if (interruptible <$> M.lookup tid cThreads) /= Just False
then unblockWaitingOn tid cThreads
else cThreads
}
_ -> ctx0
where
fixContextCommon ctx@Context{..} = ctx
{ cThreads = delCommitThreads cThreads
, cInvariants = unblockInvariants act cInvariants
, cCState = updateCState memtype cCState tid act
}
-- | @unblockWaitingOn tid@ unblocks every thread blocked in a
-- @throwTo tid@.

View File

@ -2,23 +2,26 @@
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
-- |
-- Module : Test.DejaFu.Internal
-- Copyright : (c) 2017--2018 Michael Walker
-- Copyright : (c) 2017--2019 Michael Walker
-- License : MIT
-- Maintainer : Michael Walker <mike@barrucadu.co.uk>
-- Stability : experimental
-- Portability : DeriveAnyClass, DeriveGeneric, FlexibleContexts, GADTs
-- Portability : DeriveAnyClass, DeriveGeneric, FlexibleContexts, GADTs, LambdaCase
--
-- Internal types and functions used throughout DejaFu. This module
-- is NOT considered to form part of the public interface of this
-- library.
module Test.DejaFu.Internal where
import Control.DeepSeq (NFData)
import Control.DeepSeq (NFData(..))
import Control.Exception (MaskingState(..))
import qualified Control.Monad.Conc.Class as C
import Data.List.NonEmpty (NonEmpty(..))
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Data.Maybe (fromMaybe)
import Data.Set (Set)
@ -323,6 +326,58 @@ simplifyLookahead WillSTM = SynchronisedOther
simplifyLookahead (WillThrowTo _) = SynchronisedOther
simplifyLookahead _ = UnsynchronisedOther
-------------------------------------------------------------------------------
-- * Concurrency state
-- | Initial concurrency state.
initialCState :: ConcurrencyState
initialCState = ConcurrencyState M.empty S.empty M.empty
-- | Update the concurrency state with the action that has just
-- happened.
updateCState :: MemType -> ConcurrencyState -> ThreadId -> ThreadAction -> ConcurrencyState
updateCState memtype cstate tid act = ConcurrencyState
{ concIOState = updateIOState memtype act $ concIOState cstate
, concMVState = updateMVState act $ concMVState cstate
, concMaskState = updateMaskState tid act $ concMaskState cstate
}
-- | Update the @IORef@ buffer state with the action that has just
-- happened.
updateIOState :: MemType -> ThreadAction -> Map IORefId Int -> Map IORefId Int
updateIOState SequentialConsistency _ = const M.empty
updateIOState _ (CommitIORef _ r) = (`M.alter` r) $ \case
Just 1 -> Nothing
Just n -> Just (n-1)
Nothing -> Nothing
updateIOState _ (WriteIORef r) = M.insertWith (+) r 1
updateIOState _ ta
| isBarrier $ simplifyAction ta = const M.empty
| otherwise = id
-- | Update the @MVar@ full/empty state with the action that has just
-- happened.
updateMVState :: ThreadAction -> Set MVarId -> Set MVarId
updateMVState (PutMVar mvid _) = S.insert mvid
updateMVState (TryPutMVar mvid True _) = S.insert mvid
updateMVState (TakeMVar mvid _) = S.delete mvid
updateMVState (TryTakeMVar mvid True _) = S.delete mvid
updateMVState _ = id
-- | Update the thread masking state with the action that has just
-- happened.
updateMaskState :: ThreadId -> ThreadAction -> Map ThreadId MaskingState -> Map ThreadId MaskingState
updateMaskState tid (Fork tid2) = \masks -> case M.lookup tid masks of
-- A thread inherits the masking state of its parent.
Just ms -> M.insert tid2 ms masks
Nothing -> masks
updateMaskState tid (SetMasking _ ms) = M.insert tid ms
updateMaskState tid (ResetMasking _ ms) = M.insert tid ms
updateMaskState tid (Throw True) = M.delete tid
updateMaskState _ (ThrowTo tid True) = M.delete tid
updateMaskState tid Stop = M.delete tid
updateMaskState _ _ = id
-------------------------------------------------------------------------------
-- * Error reporting

View File

@ -131,13 +131,13 @@ runSCTWithSettings settings conc = case _way settings of
check = findSchedulePrefix
step run dp (prefix, conservative, sleep) = do
step cstate0 run dp (prefix, conservative, sleep) = do
(res, s, trace) <- run
(dporSched (_safeIO settings) (_memtype settings) (cBound (_lengthBound settings) cb0))
(initialDPORSchedState sleep prefix)
(dporSched (_safeIO settings) (cBound (_lengthBound settings) cb0))
(initialDPORSchedState sleep prefix cstate0)
let bpoints = findBacktrackSteps (_safeIO settings) (_memtype settings) (cBacktrack cb0) (schedBoundKill s) (schedBPoints s) trace
let newDPOR = incorporateTrace (_safeIO settings) (_memtype settings) conservative trace dp
let bpoints = findBacktrackSteps (_safeIO settings) (_memtype settings) (cBacktrack cb0) (schedBoundKill s) cstate0 (schedBPoints s) trace
let newDPOR = incorporateTrace (_safeIO settings) (_memtype settings) conservative trace cstate0 dp
pure $ if schedIgnore s
then (force newDPOR, Nothing)
@ -150,7 +150,7 @@ runSCTWithSettings settings conc = case _way settings of
check (_, 0) = Nothing
check s = Just s
step run _ (g, n) = do
step _ run _ (g, n) = do
(res, s, trace) <- run
(randSched gen)
(initialRandSchedState (_lengthBound settings) g)

View File

@ -43,7 +43,7 @@ sct :: (MonadConc n, HasCallStack)
-- ^ Initial state
-> (s -> Maybe t)
-- ^ State predicate
-> ((Scheduler g -> g -> n (Either Condition a, g, Trace)) -> s -> t -> n (s, Maybe (Either Condition a, Trace)))
-> (ConcurrencyState -> (Scheduler g -> g -> n (Either Condition a, g, Trace)) -> s -> t -> n (s, Maybe (Either Condition a, Trace)))
-- ^ Run the computation and update the state
-> Program pty n a
-> n [(Either Condition a, Trace)]
@ -54,21 +54,26 @@ sct settings s0 sfun srun conc = recordSnapshot conc >>= \case
where
sct'Full = sct'
settings
initialCState
(s0 [initialThread])
sfun
(srun runFull)
(srun initialCState runFull)
runFull
(toId 1)
(toId 1)
sct'Snap snap = let idsrc = cIdSource (contextFromSnapshot snap) in sct'
settings
(s0 (fst (threadsFromSnapshot snap)))
sfun
(srun (runSnap snap))
(runSnap snap)
(toId $ 1 + fst (_tids idsrc))
(toId $ 1 + fst (_iorids idsrc))
sct'Snap snap =
let idsrc = cIdSource (contextFromSnapshot snap)
cstate = cCState (contextFromSnapshot snap)
in sct'
settings
cstate
(s0 (fst (threadsFromSnapshot snap)))
sfun
(srun cstate (runSnap snap))
(runSnap snap)
(toId $ 1 + fst (_tids idsrc))
(toId $ 1 + fst (_iorids idsrc))
runFull sched s = runConcurrent sched (_memtype settings) s conc
runSnap snap sched s = runSnapshot sched (_memtype settings) s snap
@ -77,6 +82,8 @@ sct settings s0 sfun srun conc = recordSnapshot conc >>= \case
sct' :: (MonadConc n, HasCallStack)
=> Settings n a
-- ^ The SCT settings ('Way' is ignored)
-> ConcurrencyState
-- ^ The initial concurrency state
-> s
-- ^ Initial state
-> (s -> Maybe t)
@ -90,7 +97,7 @@ sct' :: (MonadConc n, HasCallStack)
-> IORefId
-- ^ The first available @IORefId@
-> n [(Either Condition a, Trace)]
sct' settings s0 sfun srun run nTId nCRId = go Nothing [] s0 where
sct' settings cstate0 s0 sfun srun run nTId nCRId = go Nothing [] s0 where
go (Just res) _ _ | earlyExit res = pure []
go res0 seen !s = case sfun s of
Just t -> srun s t >>= \case
@ -122,7 +129,7 @@ sct' settings s0 sfun srun run nTId nCRId = go Nothing [] s0 where
dosimplify res trace seen s
| not (_simplify settings) = ((res, trace) :) <$> go (Just res) seen s
| otherwise = do
shrunk <- simplifyExecution settings run nTId nCRId res trace
shrunk <- simplifyExecution settings cstate0 run nTId nCRId res trace
(shrunk :) <$> go (Just res) seen s
earlyExit = fromMaybe (const False) (_earlyExit settings)
@ -146,6 +153,8 @@ sct' settings s0 sfun srun run nTId nCRId = go Nothing [] s0 where
simplifyExecution :: (MonadConc n, HasCallStack)
=> Settings n a
-- ^ The SCT settings ('Way' is ignored)
-> ConcurrencyState
-- ^ The initial concurrency state
-> (forall x. Scheduler x -> x -> n (Either Condition a, x, Trace))
-- ^ Just run the computation
-> ThreadId
@ -156,7 +165,7 @@ simplifyExecution :: (MonadConc n, HasCallStack)
-- ^ The expected result
-> Trace
-> n (Either Condition a, Trace)
simplifyExecution settings run nTId nCRId res trace
simplifyExecution settings cstate0 run nTId nCRId res trace
| tidTrace == simplifiedTrace = do
debugPrint ("Simplifying new result '" ++ p res ++ "': no simplification possible!")
pure (res, trace)
@ -172,7 +181,7 @@ simplifyExecution settings run nTId nCRId res trace
pure (res, trace)
where
tidTrace = toTIdTrace trace
simplifiedTrace = simplify (_safeIO settings) (_memtype settings) tidTrace
simplifiedTrace = simplify (_safeIO settings) (_memtype settings) cstate0 tidTrace
fixup = renumber (_memtype settings) (fromId nTId) (fromId nCRId)
debugFatal = if _debugFatal settings then fatal else debugPrint
@ -188,11 +197,11 @@ replay :: MonadConc n
-- ^ The reduced sequence of scheduling decisions
-> n (Either Condition a, [(ThreadId, ThreadAction)], Trace)
replay run = run (Scheduler (const sched)) where
sched runnable ((t, Stop):ts) = case findThread t runnable of
sched runnable cs ((t, Stop):ts) = case findThread t runnable of
Just t' -> (Just t', ts)
Nothing -> sched runnable ts
sched runnable ((t, _):ts) = (findThread t runnable, ts)
sched _ _ = (Nothing, [])
Nothing -> sched runnable cs ts
sched runnable _ ((t, _):ts) = (findThread t runnable, ts)
sched _ _ _ = (Nothing, [])
-- find a thread ignoring names
findThread tid0 =
@ -203,10 +212,15 @@ replay run = run (Scheduler (const sched)) where
-- | Simplify a trace by permuting adjacent independent actions to
-- reduce context switching.
simplify :: Bool -> MemType -> [(ThreadId, ThreadAction)] -> [(ThreadId, ThreadAction)]
simplify safeIO memtype trc0 = loop (length trc0) (prepare trc0) where
prepare = dropCommits safeIO memtype . lexicoNormalForm safeIO memtype
step = pushForward safeIO memtype . pullBack safeIO memtype
simplify
:: Bool
-> MemType
-> ConcurrencyState
-> [(ThreadId, ThreadAction)]
-> [(ThreadId, ThreadAction)]
simplify safeIO memtype cstate0 trc0 = loop (length trc0) (prepare trc0) where
prepare = dropCommits safeIO memtype cstate0 . lexicoNormalForm safeIO memtype cstate0
step = pushForward safeIO memtype cstate0 . pullBack safeIO memtype cstate0
loop 0 trc = trc
loop n trc =
@ -214,10 +228,15 @@ simplify safeIO memtype trc0 = loop (length trc0) (prepare trc0) where
in if trc' /= trc then loop (n-1) trc' else trc
-- | Put a trace into lexicographic (by thread ID) normal form.
lexicoNormalForm :: Bool -> MemType -> [(ThreadId, ThreadAction)] -> [(ThreadId, ThreadAction)]
lexicoNormalForm safeIO memtype = go where
lexicoNormalForm
:: Bool
-> MemType
-> ConcurrencyState
-> [(ThreadId, ThreadAction)]
-> [(ThreadId, ThreadAction)]
lexicoNormalForm safeIO memtype cstate0 = go where
go trc =
let trc' = permuteBy safeIO memtype (repeat (>)) trc
let trc' = permuteBy safeIO memtype cstate0 (repeat (>)) trc
in if trc == trc' then trc else go trc'
-- | Swap adjacent independent actions in the trace if a predicate
@ -225,25 +244,31 @@ lexicoNormalForm safeIO memtype = go where
permuteBy
:: Bool
-> MemType
-> ConcurrencyState
-> [ThreadId -> ThreadId -> Bool]
-> [(ThreadId, ThreadAction)]
-> [(ThreadId, ThreadAction)]
permuteBy safeIO memtype = go initialDepState where
permuteBy safeIO memtype = go where
go ds (p:ps) (t1@(tid1, ta1):t2@(tid2, ta2):trc)
| independent safeIO ds tid1 ta1 tid2 ta2 && p tid1 tid2 = go' ds ps t2 (t1 : trc)
| otherwise = go' ds ps t1 (t2 : trc)
go _ _ trc = trc
go' ds ps t@(tid, ta) trc = t : go (updateDepState memtype ds tid ta) ps trc
go' ds ps t@(tid, ta) trc = t : go (updateCState memtype ds tid ta) ps trc
-- | Throw away commit actions which are followed by a memory barrier.
dropCommits :: Bool -> MemType -> [(ThreadId, ThreadAction)] -> [(ThreadId, ThreadAction)]
dropCommits _ SequentialConsistency = id
dropCommits safeIO memtype = go initialDepState where
dropCommits
:: Bool
-> MemType
-> ConcurrencyState
-> [(ThreadId, ThreadAction)]
-> [(ThreadId, ThreadAction)]
dropCommits _ SequentialConsistency = const id
dropCommits safeIO memtype = go where
go ds (t1@(tid1, ta1@(CommitIORef _ iorefid)):t2@(tid2, ta2):trc)
| isBarrier (simplifyAction ta2) && numBuffered ds iorefid == 1 = go ds (t2:trc)
| independent safeIO ds tid1 ta1 tid2 ta2 = t2 : go (updateDepState memtype ds tid2 ta2) (t1:trc)
go ds (t@(tid,ta):trc) = t : go (updateDepState memtype ds tid ta) trc
| independent safeIO ds tid1 ta1 tid2 ta2 = t2 : go (updateCState memtype ds tid2 ta2) (t1:trc)
go ds (t@(tid,ta):trc) = t : go (updateCState memtype ds tid ta) trc
go _ [] = []
-- | Attempt to reduce context switches by \"pulling\" thread actions
@ -253,10 +278,15 @@ dropCommits safeIO memtype = go initialDepState where
-- act3)]@, where @act2@ and @act3@ are independent. In this case
-- 'pullBack' will swap them, giving the sequence @[(tidA, act1),
-- (tidA, act3), (tidB, act2)]@. It works for arbitrary separations.
pullBack :: Bool -> MemType -> [(ThreadId, ThreadAction)] -> [(ThreadId, ThreadAction)]
pullBack safeIO memtype = go initialDepState where
pullBack
:: Bool
-> MemType
-> ConcurrencyState
-> [(ThreadId, ThreadAction)]
-> [(ThreadId, ThreadAction)]
pullBack safeIO memtype = go where
go ds (t1@(tid1, ta1):trc@((tid2, _):_)) =
let ds' = updateDepState memtype ds tid1 ta1
let ds' = updateCState memtype ds tid1 ta1
trc' = if tid1 /= tid2
then maybe trc (uncurry (:)) (findAction tid1 ds' trc)
else trc
@ -266,7 +296,7 @@ pullBack safeIO memtype = go initialDepState where
findAction tid0 = fgo where
fgo ds (t@(tid, ta):trc)
| tid == tid0 = Just (t, trc)
| otherwise = case fgo (updateDepState memtype ds tid ta) trc of
| otherwise = case fgo (updateCState memtype ds tid ta) trc of
Just (ft@(ftid, fa), trc')
| independent safeIO ds tid ta ftid fa -> Just (ft, t:trc')
_ -> Nothing
@ -282,10 +312,15 @@ pullBack safeIO memtype = go initialDepState where
-- act3)]@, where @act1@ and @act2@ are independent. In this case
-- 'pushForward' will swap them, giving the sequence @[(tidB, act2),
-- (tidA, act1), (tidA, act3)]@. It works for arbitrary separations.
pushForward :: Bool -> MemType -> [(ThreadId, ThreadAction)] -> [(ThreadId, ThreadAction)]
pushForward safeIO memtype = go initialDepState where
pushForward
:: Bool
-> MemType
-> ConcurrencyState
-> [(ThreadId, ThreadAction)]
-> [(ThreadId, ThreadAction)]
pushForward safeIO memtype = go where
go ds (t1@(tid1, ta1):trc@((tid2, _):_)) =
let ds' = updateDepState memtype ds tid1 ta1
let ds' = updateCState memtype ds tid1 ta1
in if tid1 /= tid2
then maybe (t1 : go ds' trc) (go ds) (findAction tid1 ta1 ds trc)
else t1 : go ds' trc
@ -294,7 +329,7 @@ pushForward safeIO memtype = go initialDepState where
findAction tid0 ta0 = fgo where
fgo ds (t@(tid, ta):trc)
| tid == tid0 = Just ((tid0, ta0) : t : trc)
| independent safeIO ds tid0 ta0 tid ta = (t:) <$> fgo (updateDepState memtype ds tid ta) trc
| independent safeIO ds tid0 ta0 tid ta = (t:) <$> fgo (updateCState memtype ds tid ta) trc
| otherwise = Nothing
fgo _ _ = Nothing

View File

@ -1,7 +1,6 @@
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ViewPatterns #-}
-- |
@ -10,7 +9,7 @@
-- License : MIT
-- Maintainer : Michael Walker <mike@barrucadu.co.uk>
-- Stability : experimental
-- Portability : DeriveAnyClass, DeriveGeneric, FlexibleContexts, LambdaCase, ViewPatterns
-- Portability : DeriveAnyClass, DeriveGeneric, FlexibleContexts, ViewPatterns
--
-- Internal types and functions for SCT via dynamic partial-order
-- reduction. This module is NOT considered to form part of the
@ -18,8 +17,7 @@
module Test.DejaFu.SCT.Internal.DPOR where
import Control.Applicative ((<|>))
import Control.DeepSeq (NFData(..))
import Control.Exception (MaskingState(..))
import Control.DeepSeq (NFData)
import qualified Data.Foldable as F
import Data.Function (on)
import Data.List (nubBy, partition, sortOn)
@ -104,7 +102,7 @@ data BacktrackStep = BacktrackStep
, bcktBacktracks :: Map ThreadId Bool
-- ^ The list of alternative threads to run, and whether those
-- alternatives were added conservatively due to the bound.
, bcktState :: DepState
, bcktState :: ConcurrencyState
-- ^ Some domain-specific state at this point.
} deriving (Eq, Show, Generic, NFData)
@ -164,12 +162,14 @@ incorporateTrace :: HasCallStack
-> Trace
-- ^ The execution trace: the decision made, the runnable threads,
-- and the action performed.
-> ConcurrencyState
-- ^ The initial concurrency state
-> DPOR
-> DPOR
incorporateTrace safeIO memtype conservative trace dpor0 = grow initialDepState (initialDPORThread dpor0) trace dpor0 where
incorporateTrace safeIO memtype conservative trace state0 dpor0 = grow state0 (initialDPORThread dpor0) trace dpor0 where
grow state tid trc@((d, _, a):rest) dpor =
let tid' = tidOf tid d
state' = updateDepState memtype state tid' a
state' = updateCState memtype state tid' a
in case dporNext dpor of
Just (t, child)
| t == tid' ->
@ -190,7 +190,7 @@ incorporateTrace safeIO memtype conservative trace dpor0 = grow initialDepState
-- Construct a new subtree corresponding to a trace suffix.
subtree state tid sleep ((_, _, a):rest) = validateDPOR $
let state' = updateDepState memtype state tid a
let state' = updateCState memtype state tid a
sleep' = M.filterWithKey (\t a' -> not $ dependent safeIO state' tid a t a') sleep
in DPOR
{ dporRunnable = S.fromList $ case rest of
@ -235,6 +235,8 @@ findBacktrackSteps
-> Bool
-- ^ Whether the computation was aborted due to no decisions being
-- in-bounds.
-> ConcurrencyState
-- ^ The initial concurrency state.
-> Seq ([(ThreadId, Lookahead)], [ThreadId])
-- ^ A sequence of threads at each step: the list of runnable
-- in-bound threads (with lookahead values), and the list of threads
@ -244,12 +246,12 @@ findBacktrackSteps
-> Trace
-- ^ The execution trace.
-> [BacktrackStep]
findBacktrackSteps safeIO memtype backtrack boundKill = go initialDepState S.empty initialThread [] . F.toList where
findBacktrackSteps safeIO memtype backtrack boundKill state0 = go state0 S.empty initialThread [] . F.toList where
-- Walk through the traces one step at a time, building up a list of
-- new backtracking points.
go state allThreads tid bs ((e,i):is) ((d,_,a):ts) =
let tid' = tidOf tid d
state' = updateDepState memtype state tid' a
state' = updateCState memtype state tid' a
this = BacktrackStep
{ bcktThreadid = tid'
, bcktDecision = d
@ -344,7 +346,7 @@ data DPORSchedState k = DPORSchedState
, schedBoundKill :: Bool
-- ^ Whether the execution was terminated due to all decisions being
-- out of bounds.
, schedDepState :: DepState
, schedCState :: ConcurrencyState
-- ^ State used by the dependency function to determine when to
-- remove decisions from the sleep set.
, schedBState :: Maybe k
@ -356,14 +358,16 @@ initialDPORSchedState :: Map ThreadId ThreadAction
-- ^ The initial sleep set.
-> [ThreadId]
-- ^ The schedule prefix.
-> ConcurrencyState
-- ^ The initial concurrency state.
-> DPORSchedState k
initialDPORSchedState sleep prefix = DPORSchedState
initialDPORSchedState sleep prefix state0 = DPORSchedState
{ schedSleep = sleep
, schedPrefix = prefix
, schedBPoints = Sq.empty
, schedIgnore = False
, schedBoundKill = False
, schedDepState = initialDepState
, schedCState = state0
, schedBState = Nothing
}
@ -442,19 +446,22 @@ backtrackAt toAll bs0 = backtrackAt' . nubBy ((==) `on` fst') . sortOn fst' wher
dporSched :: HasCallStack
=> Bool
-- ^ True if all IO is thread safe.
-> MemType
-> IncrementalBoundFunc k
-- ^ Bound function: returns true if that schedule prefix terminated
-- with the lookahead decision fits within the bound.
-> Scheduler (DPORSchedState k)
dporSched safeIO memtype boundf = Scheduler $ \prior threads s ->
dporSched safeIO boundf = Scheduler $ \prior threads cstate s ->
let
-- The next scheduler state
nextState rest = s
{ schedBPoints = schedBPoints s |> (restrictToBound fst threads', rest)
, schedDepState = nextDepState
-- we only update this after using the current value; so in
-- effect this field is the depstate *before* the action which
-- just happened, we need this because we need to know if the
-- prior action (in the state we did it from) is dependent with
-- anything in the sleep set.
, schedCState = cstate
}
nextDepState = let ds = schedDepState s in maybe ds (uncurry $ updateDepState memtype ds) prior
-- Pick a new thread to run, not considering bounds. Choose the
-- current thread if available and it hasn't just yielded,
@ -522,7 +529,7 @@ dporSched safeIO memtype boundf = Scheduler $ \prior threads s ->
[] ->
let choices = restrictToBound id initialise
checkDep t a = case prior of
Just (tid, act) -> dependent safeIO (schedDepState s) tid act t a
Just (tid, act) -> dependent safeIO (schedCState s) tid act t a
Nothing -> False
ssleep' = M.filterWithKey (\t a -> not $ checkDep t a) $ schedSleep s
choices' = filter (`notElem` M.keys ssleep') choices
@ -542,7 +549,7 @@ dporSched safeIO memtype boundf = Scheduler $ \prior threads s ->
--
-- This implements a stronger check that @not (dependent ...)@, as it
-- handles some cases which 'dependent' doesn't need to care about.
independent :: Bool -> DepState -> ThreadId -> ThreadAction -> ThreadId -> ThreadAction -> Bool
independent :: Bool -> ConcurrencyState -> ThreadId -> ThreadAction -> ThreadId -> ThreadAction -> Bool
independent safeIO ds t1 a1 t2 a2
| t1 == t2 = False
| check t1 a1 t2 a2 = False
@ -569,7 +576,7 @@ independent safeIO ds t1 a1 t2 a2
-- This is basically the same as 'dependent'', but can make use of the
-- additional information in a 'ThreadAction' to make better decisions
-- in a few cases.
dependent :: Bool -> DepState -> ThreadId -> ThreadAction -> ThreadId -> ThreadAction -> Bool
dependent :: Bool -> ConcurrencyState -> ThreadId -> ThreadAction -> ThreadId -> ThreadAction -> Bool
dependent safeIO ds t1 a1 t2 a2 = case (a1, a2) of
-- When masked interruptible, a thread can only be interrupted when
-- actually blocked. 'dependent'' has to assume that all
@ -601,7 +608,7 @@ dependent safeIO ds t1 a1 t2 a2 = case (a1, a2) of
--
-- Termination of the initial thread is handled specially in the DPOR
-- implementation.
dependent' :: Bool -> DepState -> ThreadId -> ThreadAction -> ThreadId -> Lookahead -> Bool
dependent' :: Bool -> ConcurrencyState -> ThreadId -> ThreadAction -> ThreadId -> Lookahead -> Bool
dependent' safeIO ds t1 a1 t2 l2 = case (a1, l2) of
-- Worst-case assumption: all IO is dependent.
(LiftIO, WillLiftIO) -> not safeIO
@ -630,7 +637,7 @@ dependent' safeIO ds t1 a1 t2 l2 = case (a1, l2) of
-- | Check if two 'ActionType's are dependent. Note that this is not
-- sufficient to know if two 'ThreadAction's are dependent, without
-- being so great an over-approximation as to be useless!
dependentActions :: DepState -> ActionType -> ActionType -> Bool
dependentActions :: ConcurrencyState -> ActionType -> ActionType -> Bool
dependentActions ds a1 a2 = case (a1, a2) of
(UnsynchronisedRead _, UnsynchronisedRead _) -> False
@ -664,131 +671,6 @@ dependentActions ds a1 a2 = case (a1, a2) of
(_, _) -> maybe False (\r -> Just r == iorefOf a2) (iorefOf a1)
-------------------------------------------------------------------------------
-- ** Dependency function state
data DepState = DepState
{ depIOState :: Map IORefId Int
-- ^ Keep track of which @IORef@s have buffered writes.
, depMVState :: Set MVarId
-- ^ Keep track of which @MVar@s are full.
, depMaskState :: Map ThreadId MaskingState
-- ^ Keep track of thread masking states. If a thread isn't present,
-- the masking state is assumed to be @Unmasked@. This nicely
-- provides compatibility with dpor-0.1, where the thread IDs are
-- not available.
} deriving (Eq, Show)
instance NFData DepState where
rnf depstate = rnf ( depIOState depstate
, depMVState depstate
, [(t, m `seq` ()) | (t, m) <- M.toList (depMaskState depstate)]
)
-- | Initial dependency state.
initialDepState :: DepState
initialDepState = DepState M.empty S.empty M.empty
-- | Update the dependency state with the action that has just
-- happened.
updateDepState :: MemType -> DepState -> ThreadId -> ThreadAction -> DepState
updateDepState memtype depstate tid act = DepState
{ depIOState = updateIOState memtype act $ depIOState depstate
, depMVState = updateMVState act $ depMVState depstate
, depMaskState = updateMaskState tid act $ depMaskState depstate
}
-- | Update the @IORef@ buffer state with the action that has just
-- happened.
updateIOState :: MemType -> ThreadAction -> Map IORefId Int -> Map IORefId Int
updateIOState SequentialConsistency _ = const M.empty
updateIOState _ (CommitIORef _ r) = (`M.alter` r) $ \case
Just 1 -> Nothing
Just n -> Just (n-1)
Nothing -> Nothing
updateIOState _ (WriteIORef r) = M.insertWith (+) r 1
updateIOState _ ta
| isBarrier $ simplifyAction ta = const M.empty
| otherwise = id
-- | Update the @MVar@ full/empty state with the action that has just
-- happened.
updateMVState :: ThreadAction -> Set MVarId -> Set MVarId
updateMVState (PutMVar mvid _) = S.insert mvid
updateMVState (TryPutMVar mvid True _) = S.insert mvid
updateMVState (TakeMVar mvid _) = S.delete mvid
updateMVState (TryTakeMVar mvid True _) = S.delete mvid
updateMVState _ = id
-- | Update the thread masking state with the action that has just
-- happened.
updateMaskState :: ThreadId -> ThreadAction -> Map ThreadId MaskingState -> Map ThreadId MaskingState
updateMaskState tid (Fork tid2) = \masks -> case M.lookup tid masks of
-- A thread inherits the masking state of its parent.
Just ms -> M.insert tid2 ms masks
Nothing -> masks
updateMaskState tid (SetMasking _ ms) = M.insert tid ms
updateMaskState tid (ResetMasking _ ms) = M.insert tid ms
updateMaskState tid (Throw True) = M.delete tid
updateMaskState _ (ThrowTo tid True) = M.delete tid
updateMaskState tid Stop = M.delete tid
updateMaskState _ _ = id
-- | Check if a @IORef@ has a buffered write pending.
isBuffered :: DepState -> IORefId -> Bool
isBuffered depstate r = numBuffered depstate r /= 0
-- | Check how many buffered writes an @IORef@ has.
numBuffered :: DepState -> IORefId -> Int
numBuffered depstate r = M.findWithDefault 0 r (depIOState depstate)
-- | Check if an @MVar@ is full.
isFull :: DepState -> MVarId -> Bool
isFull depstate v = S.member v (depMVState depstate)
-- | Check if an exception can interrupt a thread (action).
canInterrupt :: DepState -> ThreadId -> ThreadAction -> Bool
canInterrupt depstate tid act
-- If masked interruptible, blocked actions can be interrupted.
| isMaskedInterruptible depstate tid = case act of
BlockedPutMVar _ -> True
BlockedReadMVar _ -> True
BlockedTakeMVar _ -> True
BlockedSTM _ -> True
BlockedThrowTo _ -> True
_ -> False
-- If masked uninterruptible, nothing can be.
| isMaskedUninterruptible depstate tid = False
-- If no mask, anything can be.
| otherwise = True
-- | Check if an exception can interrupt a thread (lookahead).
canInterruptL :: DepState -> ThreadId -> Lookahead -> Bool
canInterruptL depstate tid lh
-- If masked interruptible, actions which can block may be
-- interrupted.
| isMaskedInterruptible depstate tid = case lh of
WillPutMVar _ -> True
WillReadMVar _ -> True
WillTakeMVar _ -> True
WillSTM -> True
WillThrowTo _ -> True
_ -> False
-- If masked uninterruptible, nothing can be.
| isMaskedUninterruptible depstate tid = False
-- If no mask, anything can be.
| otherwise = True
-- | Check if a thread is masked interruptible.
isMaskedInterruptible :: DepState -> ThreadId -> Bool
isMaskedInterruptible depstate tid =
M.lookup tid (depMaskState depstate) == Just MaskedInterruptible
-- | Check if a thread is masked uninterruptible.
isMaskedUninterruptible :: DepState -> ThreadId -> Bool
isMaskedUninterruptible depstate tid =
M.lookup tid (depMaskState depstate) == Just MaskedUninterruptible
-------------------------------------------------------------------------------
-- * Utilities

View File

@ -45,7 +45,7 @@ initialRandSchedState = RandSchedState M.empty
-- and makes a weighted random choice out of the runnable threads at
-- every step.
randSched :: RandomGen g => (g -> (Int, g)) -> Scheduler (RandSchedState g)
randSched weightf = Scheduler $ \_ threads s ->
randSched weightf = Scheduler $ \_ threads _ s ->
let
-- Select a thread
pick idx ((x, f):xs)

View File

@ -37,16 +37,19 @@ import Test.DejaFu.Types
--
-- 2. The unblocked threads.
--
-- 3. The state.
-- 3. The concurrency state.
--
-- 4. The scheduler state.
--
-- It returns a thread to execute, or @Nothing@ if execution should
-- abort here, and also a new state.
--
-- @since 0.8.0.0
-- @since 2.0.0.0
newtype Scheduler state = Scheduler
{ scheduleThread
:: Maybe (ThreadId, ThreadAction)
-> NonEmpty (ThreadId, Lookahead)
-> ConcurrencyState
-> state
-> (Maybe ThreadId, state)
}
@ -60,7 +63,7 @@ newtype Scheduler state = Scheduler
-- @since 0.8.0.0
randomSched :: RandomGen g => Scheduler g
randomSched = Scheduler go where
go _ threads g =
go _ threads _ g =
let threads' = map fst (toList threads)
(choice, g') = randomR (0, length threads' - 1) g
in (Just $ eidx threads' choice, g')
@ -71,8 +74,8 @@ randomSched = Scheduler go where
-- @since 0.8.0.0
roundRobinSched :: Scheduler ()
roundRobinSched = Scheduler go where
go Nothing ((tid,_):|_) _ = (Just tid, ())
go (Just (prior, _)) threads _ =
go Nothing ((tid,_):|_) _ _ = (Just tid, ())
go (Just (prior, _)) threads _ _ =
let threads' = map fst (toList threads)
candidates =
if prior >= maximum threads'
@ -109,7 +112,7 @@ roundRobinSchedNP = makeNonPreemptive roundRobinSched
-- @since 0.8.0.0
makeNonPreemptive :: Scheduler s -> Scheduler s
makeNonPreemptive sched = Scheduler newsched where
newsched p@(Just (prior, _)) threads s
newsched p@(Just (prior, _)) threads cs s
| prior `elem` map fst (toList threads) = (Just prior, s)
| otherwise = scheduleThread sched p threads s
newsched Nothing threads s = scheduleThread sched Nothing threads s
| otherwise = scheduleThread sched p threads cs s
newsched Nothing threads cs s = scheduleThread sched Nothing threads cs s

View File

@ -5,7 +5,7 @@
-- |
-- Module : Test.DejaFu.Types
-- Copyright : (c) 2017--2018 Michael Walker
-- Copyright : (c) 2017--2019 Michael Walker
-- License : MIT
-- Maintainer : Michael Walker <mike@barrucadu.co.uk>
-- Stability : experimental
@ -21,7 +21,11 @@ import Control.Exception (Exception(..),
import Data.Function (on)
import Data.Functor.Contravariant (Contravariant(..))
import Data.Functor.Contravariant.Divisible (Divisible(..))
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Data.Semigroup (Semigroup(..))
import Data.Set (Set)
import qualified Data.Set as S
import GHC.Generics (Generic)
-------------------------------------------------------------------------------
@ -721,3 +725,97 @@ deriving instance Generic MonadFailException
-- | @since 1.3.1.0
instance NFData MonadFailException
-------------------------------------------------------------------------------
-- ** Concurrency state
-- | A summary of the concurrency state of the program.
--
-- @since 2.0.0.0
data ConcurrencyState = ConcurrencyState
{ concIOState :: Map IORefId Int
-- ^ Keep track of which @IORef@s have buffered writes.
, concMVState :: Set MVarId
-- ^ Keep track of which @MVar@s are full.
, concMaskState :: Map ThreadId MaskingState
-- ^ Keep track of thread masking states. If a thread isn't present,
-- the masking state is assumed to be @Unmasked@. This nicely
-- provides compatibility with dpor-0.1, where the thread IDs are
-- not available.
} deriving (Eq, Show)
instance NFData ConcurrencyState where
rnf cstate = rnf
( concIOState cstate
, concMVState cstate
, [(t, show m) | (t, m) <- M.toList (concMaskState cstate)]
)
-- | Check if a @IORef@ has a buffered write pending.
--
-- @since 2.0.0.0
isBuffered :: ConcurrencyState -> IORefId -> Bool
isBuffered cstate r = numBuffered cstate r /= 0
-- | Check how many buffered writes an @IORef@ has.
--
-- @since 2.0.0.0
numBuffered :: ConcurrencyState -> IORefId -> Int
numBuffered cstate r = M.findWithDefault 0 r (concIOState cstate)
-- | Check if an @MVar@ is full.
--
-- @since 2.0.0.0
isFull :: ConcurrencyState -> MVarId -> Bool
isFull cstate v = S.member v (concMVState cstate)
-- | Check if an exception can interrupt a thread (action).
--
-- @since 2.0.0.0
canInterrupt :: ConcurrencyState -> ThreadId -> ThreadAction -> Bool
canInterrupt cstate tid act
-- If masked interruptible, blocked actions can be interrupted.
| isMaskedInterruptible cstate tid = case act of
BlockedPutMVar _ -> True
BlockedReadMVar _ -> True
BlockedTakeMVar _ -> True
BlockedSTM _ -> True
BlockedThrowTo _ -> True
_ -> False
-- If masked uninterruptible, nothing can be.
| isMaskedUninterruptible cstate tid = False
-- If no mask, anything can be.
| otherwise = True
-- | Check if an exception can interrupt a thread (lookahead).
--
-- @since 2.0.0.0
canInterruptL :: ConcurrencyState -> ThreadId -> Lookahead -> Bool
canInterruptL cstate tid lh
-- If masked interruptible, actions which can block may be
-- interrupted.
| isMaskedInterruptible cstate tid = case lh of
WillPutMVar _ -> True
WillReadMVar _ -> True
WillTakeMVar _ -> True
WillSTM -> True
WillThrowTo _ -> True
_ -> False
-- If masked uninterruptible, nothing can be.
| isMaskedUninterruptible cstate tid = False
-- If no mask, anything can be.
| otherwise = True
-- | Check if a thread is masked interruptible.
--
-- @since 2.0.0.0
isMaskedInterruptible :: ConcurrencyState -> ThreadId -> Bool
isMaskedInterruptible cstate tid =
M.lookup tid (concMaskState cstate) == Just MaskedInterruptible
-- | Check if a thread is masked uninterruptible.
--
-- @since 2.0.0.0
isMaskedUninterruptible :: ConcurrencyState -> ThreadId -> Bool
isMaskedUninterruptible cstate tid =
M.lookup tid (concMaskState cstate) == Just MaskedUninterruptible