Merge branch 'subconcurrency'

This commit is contained in:
Michael Walker 2017-02-18 20:45:15 +00:00
commit 75d2b6ca73
12 changed files with 555 additions and 613 deletions

View File

@ -11,6 +11,7 @@ import Test.HUnit.DejaFu (testDejafu)
import Control.Concurrent.Classy
import Control.Monad.STM.Class
import Test.DejaFu.Conc (Conc, subconcurrency)
#if __GLASGOW_HASKELL__ < 710
import Control.Applicative ((<$>), (<*>))
@ -49,6 +50,13 @@ tests =
, testGroup "Daemons" . hUnitTestToTests $ test
[ testDejafu schedDaemon "schedule daemon" $ gives' [0,1]
]
, testGroup "Subconcurrency" . hUnitTestToTests $ test
[ testDejafu scDeadlock1 "deadlock1" $ gives' [Left Deadlock, Right ()]
, testDejafu scDeadlock2 "deadlock2" $ gives' [(Left Deadlock, ()), (Right (), ())]
, testDejafu scSuccess "success" $ gives' [Right ()]
, testDejafu scIllegal "illegal" $ gives [Left IllegalSubconcurrency]
]
]
--------------------------------------------------------------------------------
@ -207,3 +215,40 @@ schedDaemon = do
x <- newCRef 0
_ <- fork $ myThreadId >> writeCRef x 1
readCRef x
--------------------------------------------------------------------------------
-- Subconcurrency
-- | Subcomputation deadlocks sometimes.
scDeadlock1 :: Monad n => Conc n r (Either Failure ())
scDeadlock1 = do
var <- newEmptyMVar
subconcurrency $ do
void . fork $ putMVar var ()
putMVar var ()
-- | Subcomputation deadlocks sometimes, and action after it still
-- happens.
scDeadlock2 :: Monad n => Conc n r (Either Failure (), ())
scDeadlock2 = do
var <- newEmptyMVar
res <- subconcurrency $ do
void . fork $ putMVar var ()
putMVar var ()
(,) <$> pure res <*> readMVar var
-- | Subcomputation successfully completes.
scSuccess :: Monad n => Conc n r (Either Failure ())
scSuccess = do
var <- newMVar ()
subconcurrency $ do
out <- newEmptyMVar
void . fork $ takeMVar var >>= putMVar out
takeMVar out
-- | Illegal usage
scIllegal :: Monad n => Conc n r ()
scIllegal = do
var <- newEmptyMVar
void . fork $ readMVar var
void . subconcurrency $ pure ()

View File

@ -3,6 +3,7 @@
module Cases.SingleThreaded where
import Control.Exception (ArithException(..), ArrayException(..))
import Control.Monad (void)
import Test.DejaFu (Failure(..), gives, gives')
import Test.Framework (Test, testGroup)
import Test.Framework.Providers.HUnit (hUnitTestToTests)
@ -10,7 +11,7 @@ import Test.HUnit (test)
import Test.HUnit.DejaFu (testDejafu)
import Control.Concurrent.Classy
import Control.Monad.STM.Class
import Test.DejaFu.Conc (Conc, subconcurrency)
import Utils
@ -58,6 +59,12 @@ tests =
[ testDejafu capsGet "get" $ gives' [True]
, testDejafu capsSet "set" $ gives' [True]
]
, testGroup "Subconcurrency" . hUnitTestToTests $ test
[ testDejafu scDeadlock1 "deadlock1" $ gives' [Left Deadlock]
, testDejafu scDeadlock2 "deadlock2" $ gives' [(Left Deadlock, ())]
, testDejafu scSuccess "success" $ gives' [Right ()]
]
]
--------------------------------------------------------------------------------
@ -252,3 +259,22 @@ capsSet = do
caps <- getNumCapabilities
setNumCapabilities $ caps + 1
(== caps + 1) <$> getNumCapabilities
--------------------------------------------------------------------------------
-- Subconcurrency
-- | Subcomputation deadlocks.
scDeadlock1 :: Monad n => Conc n r (Either Failure ())
scDeadlock1 = subconcurrency (newEmptyMVar >>= readMVar)
-- | Subcomputation deadlocks, and action after it still happens.
scDeadlock2 :: Monad n => Conc n r (Either Failure (), ())
scDeadlock2 = do
var <- newMVar ()
(,) <$> subconcurrency (putMVar var ()) <*> readMVar var
-- | Subcomputation successfully completes.
scSuccess :: Monad n => Conc n r (Either Failure ())
scSuccess = do
var <- newMVar ()
subconcurrency (takeMVar var)

View File

@ -240,7 +240,6 @@ module Test.DejaFu
) where
import Control.Arrow (first)
import Control.DeepSeq (NFData(..))
import Control.Monad (when, unless)
import Control.Monad.Ref (MonadRef)
import Control.Monad.ST (runST)
@ -400,9 +399,6 @@ defaultFail failures = Result False 0 failures ""
defaultPass :: Result a
defaultPass = Result True 0 [] ""
instance NFData a => NFData (Result a) where
rnf r = rnf (_pass r, _casesChecked r, _failures r, _failureMsg r)
instance Functor Result where
fmap f r = r { _failures = map (first $ fmap f) $ _failures r }

View File

@ -60,7 +60,6 @@ module Test.DejaFu.Common
, MemType(..)
) where
import Control.DeepSeq (NFData(..))
import Control.Exception (MaskingState(..))
import Data.Dynamic (Dynamic)
import Data.List (sort, nub, intercalate)
@ -83,9 +82,6 @@ instance Show ThreadId where
show (ThreadId (Just n) _) = n
show (ThreadId Nothing i) = show i
instance NFData ThreadId where
rnf (ThreadId n i) = rnf (n, i)
-- | Every @CRef@ has a unique identifier.
data CRefId = CRefId (Maybe String) Int
deriving Eq
@ -97,9 +93,6 @@ instance Show CRefId where
show (CRefId (Just n) _) = n
show (CRefId Nothing i) = show i
instance NFData CRefId where
rnf (CRefId n i) = rnf (n, i)
-- | Every @MVar@ has a unique identifier.
data MVarId = MVarId (Maybe String) Int
deriving Eq
@ -111,9 +104,6 @@ instance Show MVarId where
show (MVarId (Just n) _) = n
show (MVarId Nothing i) = show i
instance NFData MVarId where
rnf (MVarId n i) = rnf (n, i)
-- | Every @TVar@ has a unique identifier.
data TVarId = TVarId (Maybe String) Int
deriving Eq
@ -125,9 +115,6 @@ instance Show TVarId where
show (TVarId (Just n) _) = n
show (TVarId Nothing i) = show i
instance NFData TVarId where
rnf (TVarId n i) = rnf (n, i)
-- | The ID of the initial thread.
initialThread :: ThreadId
initialThread = ThreadId (Just "main") 0
@ -272,38 +259,10 @@ data ThreadAction =
-- ^ A '_concMessage' annotation was processed.
| Stop
-- ^ Cease execution and terminate.
| Subconcurrency
-- ^ Start executing an action with @subconcurrency@.
deriving Show
instance NFData ThreadAction where
rnf (Fork t) = rnf t
rnf (GetNumCapabilities i) = rnf i
rnf (SetNumCapabilities i) = rnf i
rnf (NewVar c) = rnf c
rnf (PutVar c ts) = rnf (c, ts)
rnf (BlockedPutVar c) = rnf c
rnf (TryPutVar c b ts) = rnf (c, b, ts)
rnf (ReadVar c) = rnf c
rnf (BlockedReadVar c) = rnf c
rnf (TakeVar c ts) = rnf (c, ts)
rnf (BlockedTakeVar c) = rnf c
rnf (TryTakeVar c b ts) = rnf (c, b, ts)
rnf (NewRef c) = rnf c
rnf (ReadRef c) = rnf c
rnf (ReadRefCas c) = rnf c
rnf (ModRef c) = rnf c
rnf (ModRefCas c) = rnf c
rnf (WriteRef c) = rnf c
rnf (CasRef c b) = rnf (c, b)
rnf (CommitRef t c) = rnf (t, c)
rnf (STM s ts) = rnf (s, ts)
rnf (BlockedSTM s) = rnf s
rnf (ThrowTo t) = rnf t
rnf (BlockedThrowTo t) = rnf t
rnf (SetMasking b m) = b `seq` m `seq` ()
rnf (ResetMasking b m) = b `seq` m `seq` ()
rnf (Message m) = m `seq` ()
rnf a = a `seq` ()
-- | Check if a @ThreadAction@ immediately blocks.
isBlock :: ThreadAction -> Bool
isBlock (BlockedThrowTo _) = True
@ -401,28 +360,10 @@ data Lookahead =
-- ^ Will process a _concMessage' annotation.
| WillStop
-- ^ Will cease execution and terminate.
| WillSubconcurrency
-- ^ Will execute an action with @subconcurrency@.
deriving Show
instance NFData Lookahead where
rnf (WillSetNumCapabilities i) = rnf i
rnf (WillPutVar c) = rnf c
rnf (WillTryPutVar c) = rnf c
rnf (WillReadVar c) = rnf c
rnf (WillTakeVar c) = rnf c
rnf (WillTryTakeVar c) = rnf c
rnf (WillReadRef c) = rnf c
rnf (WillReadRefCas c) = rnf c
rnf (WillModRef c) = rnf c
rnf (WillModRefCas c) = rnf c
rnf (WillWriteRef c) = rnf c
rnf (WillCasRef c) = rnf c
rnf (WillCommitRef t c) = rnf (t, c)
rnf (WillThrowTo t) = rnf t
rnf (WillSetMasking b m) = b `seq` m `seq` ()
rnf (WillResetMasking b m) = b `seq` m `seq` ()
rnf (WillMessage m) = m `seq` ()
rnf l = l `seq` ()
-- | Convert a 'ThreadAction' into a 'Lookahead': \"rewind\" what has
-- happened. 'Killed' has no 'Lookahead' counterpart.
rewind :: ThreadAction -> Maybe Lookahead
@ -462,6 +403,7 @@ rewind LiftIO = Just WillLiftIO
rewind Return = Just WillReturn
rewind (Message m) = Just (WillMessage m)
rewind Stop = Just WillStop
rewind Subconcurrency = Just WillSubconcurrency
-- | Check if an operation could enable another thread.
willRelease :: Lookahead -> Bool
@ -508,17 +450,6 @@ data ActionType =
-- communication.
deriving (Eq, Show)
instance NFData ActionType where
rnf (UnsynchronisedRead r) = rnf r
rnf (UnsynchronisedWrite r) = rnf r
rnf (PartiallySynchronisedCommit r) = rnf r
rnf (PartiallySynchronisedWrite r) = rnf r
rnf (PartiallySynchronisedModify r) = rnf r
rnf (SynchronisedModify r) = rnf r
rnf (SynchronisedRead c) = rnf c
rnf (SynchronisedWrite c) = rnf c
rnf a = a `seq` ()
-- | Check if an action imposes a write barrier.
isBarrier :: ActionType -> Bool
isBarrier (SynchronisedModify _) = True
@ -611,13 +542,6 @@ data TAction =
-- ^ Terminate successfully and commit effects.
deriving (Eq, Show)
instance NFData TAction where
rnf (TRead v) = rnf v
rnf (TWrite v) = rnf v
rnf (TCatch s m) = rnf (s, m)
rnf (TOrElse s m) = rnf (s, m)
rnf a = a `seq` ()
-------------------------------------------------------------------------------
-- Traces
@ -640,11 +564,6 @@ data Decision =
-- ^ Pre-empt the running thread, and switch to another.
deriving (Eq, Show)
instance NFData Decision where
rnf (Start tid) = rnf tid
rnf (SwitchTo tid) = rnf tid
rnf d = d `seq` ()
-- | Pretty-print a trace, including a key of the thread IDs (not
-- including thread 0). Each line of the key is indented by two
-- spaces.
@ -675,15 +594,15 @@ showTrace trc = intercalate "\n" $ concatMap go trc : strkey where
preEmpCount :: [(Decision, ThreadAction)]
-> (Decision, Lookahead)
-> Int
preEmpCount ts (d, _) = go initialThread Nothing ts where
go _ (Just Yield) ((SwitchTo t, a):rest) = go t (Just a) rest
go tid prior ((SwitchTo t, a):rest)
preEmpCount (x:xs) (d, _) = go initialThread x xs where
go _ (_, Yield) (r@(SwitchTo t, _):rest) = go t r rest
go tid prior (r@(SwitchTo t, _):rest)
| isCommitThread t = go tid prior (skip rest)
| otherwise = 1 + go t (Just a) rest
go _ _ ((Start t, a):rest) = go t (Just a) rest
go tid _ ((Continue, a):rest) = go tid (Just a) rest
| otherwise = 1 + go t r rest
go _ _ (r@(Start t, _):rest) = go t r rest
go tid _ (r@(Continue, _):rest) = go tid r rest
go _ prior [] = case (prior, d) of
(Just Yield, SwitchTo _) -> 0
((_, Yield), SwitchTo _) -> 0
(_, SwitchTo _) -> 1
_ -> 0
@ -694,6 +613,7 @@ preEmpCount ts (d, _) = go initialThread Nothing ts where
skip = dropWhile (not . isContextSwitch . fst)
isContextSwitch Continue = False
isContextSwitch _ = True
preEmpCount [] _ = 0
-------------------------------------------------------------------------------
-- Failures
@ -716,11 +636,11 @@ data Failure =
-- ^ The computation became blocked indefinitely on @TVar@s.
| UncaughtException
-- ^ An uncaught exception bubbled to the top of the computation.
| IllegalSubconcurrency
-- ^ Calls to @subconcurrency@ were nested, or attempted when
-- multiple threads existed.
deriving (Eq, Show, Read, Ord, Enum, Bounded)
instance NFData Failure where
rnf f = f `seq` ()
-- | Pretty-print a failure
showFail :: Failure -> String
showFail Abort = "[abort]"
@ -728,6 +648,7 @@ showFail Deadlock = "[deadlock]"
showFail STMDeadlock = "[stm-deadlock]"
showFail InternalError = "[internal-error]"
showFail UncaughtException = "[exception]"
showFail IllegalSubconcurrency = "[illegal-subconcurrency]"
-------------------------------------------------------------------------------
-- Memory Models
@ -751,9 +672,6 @@ data MemType =
-- created.
deriving (Eq, Show, Read, Ord, Enum, Bounded)
instance NFData MemType where
rnf m = m `seq` ()
-------------------------------------------------------------------------------
-- Utilities

View File

@ -28,6 +28,7 @@ module Test.DejaFu.Conc
, Failure(..)
, MemType(..)
, runConcurrent
, subconcurrency
-- * Execution traces
, Trace
@ -49,12 +50,11 @@ import Control.Exception (MaskingState(..))
import qualified Control.Monad.Base as Ba
import qualified Control.Monad.Catch as Ca
import qualified Control.Monad.IO.Class as IO
import Control.Monad.Ref (MonadRef, newRef, readRef, writeRef)
import Control.Monad.Ref (MonadRef,)
import Control.Monad.ST (ST)
import Data.Dynamic (toDyn)
import qualified Data.Foldable as F
import Data.IORef (IORef)
import qualified Data.Map.Strict as M
import Data.Maybe (fromJust)
import Data.STRef (STRef)
import Test.DejaFu.Schedule
@ -62,13 +62,12 @@ import qualified Control.Monad.Conc.Class as C
import Test.DejaFu.Common
import Test.DejaFu.Conc.Internal
import Test.DejaFu.Conc.Internal.Common
import Test.DejaFu.Conc.Internal.Threading
import Test.DejaFu.STM
{-# ANN module ("HLint: ignore Avoid lambda" :: String) #-}
{-# ANN module ("HLint: ignore Use const" :: String) #-}
newtype Conc n r a = C { unC :: M n r (STMLike n r) a } deriving (Functor, Applicative, Monad)
newtype Conc n r a = C { unC :: M n r a } deriving (Functor, Applicative, Monad)
-- | A 'MonadConc' implementation using @ST@, this should be preferred
-- if you do not need 'liftIO'.
@ -77,10 +76,10 @@ type ConcST t = Conc (ST t) (STRef t)
-- | A 'MonadConc' implementation using @IO@.
type ConcIO = Conc IO IORef
toConc :: ((a -> Action n r (STMLike n r)) -> Action n r (STMLike n r)) -> Conc n r a
toConc :: ((a -> Action n r) -> Action n r) -> Conc n r a
toConc = C . cont
wrap :: (M n r (STMLike n r) a -> M n r (STMLike n r) a) -> Conc n r a -> Conc n r a
wrap :: (M n r a -> M n r a) -> Conc n r a -> Conc n r a
wrap f = C . f . unC
instance IO.MonadIO ConcIO where
@ -181,20 +180,15 @@ runConcurrent :: MonadRef r n
-> s
-> Conc n r a
-> n (Either Failure a, s, Trace)
runConcurrent sched memtype s (C conc) = do
ref <- newRef Nothing
runConcurrent sched memtype s ma = do
(res, s', trace) <- runConcurrency sched memtype s (unC ma)
pure (res, s', F.toList trace)
let c = runCont conc (AStop . writeRef ref . Just . Right)
let threads = launch' Unmasked initialThread (const c) M.empty
(s', trace) <- runThreads runTransaction
sched
memtype
s
threads
initialIdSource
ref
out <- readRef ref
pure (fromJust out, s', reverse trace)
-- | Run a concurrent computation and return its result.
--
-- This can only be called in the main thread, when no other threads
-- exist. Calls to 'subconcurrency' cannot be nested. Violating either
-- of these conditions will result in the computation failing with
-- @IllegalSubconcurrency@.
subconcurrency :: Conc n r a -> Conc n r (Either Failure a)
subconcurrency ma = toConc (ASub (unC ma))

View File

@ -16,19 +16,23 @@
module Test.DejaFu.Conc.Internal where
import Control.Exception (MaskingState(..), toException)
import Control.Monad.Ref (MonadRef, newRef, writeRef)
import Control.Monad.Ref (MonadRef, newRef, readRef, writeRef)
import qualified Data.Foldable as F
import Data.Functor (void)
import Data.List (sort)
import Data.List.NonEmpty (NonEmpty(..), fromList)
import qualified Data.Map.Strict as M
import Data.Maybe (fromJust, isJust, isNothing, listToMaybe)
import Data.Maybe (fromJust, isJust, isNothing)
import Data.Monoid ((<>))
import Data.Sequence (Seq, (<|))
import qualified Data.Sequence as Seq
import Test.DejaFu.Common
import Test.DejaFu.Conc.Internal.Common
import Test.DejaFu.Conc.Internal.Memory
import Test.DejaFu.Conc.Internal.Threading
import Test.DejaFu.Schedule
import Test.DejaFu.STM (Result(..))
import Test.DejaFu.STM (Result(..), runTransaction)
{-# ANN module ("HLint: ignore Use record patterns" :: String) #-}
{-# ANN module ("HLint: ignore Use const" :: String) #-}
@ -36,40 +40,69 @@ import Test.DejaFu.STM (Result(..))
--------------------------------------------------------------------------------
-- * Execution
-- | 'Trace' but as a sequence.
type SeqTrace
= Seq (Decision, [(ThreadId, NonEmpty Lookahead)], ThreadAction)
-- | Run a concurrent computation with a given 'Scheduler' and initial
-- state, returning a failure reason on error. Also returned is the
-- final state of the scheduler, and an execution trace.
runConcurrency :: MonadRef r n
=> Scheduler g
-> MemType
-> g
-> M n r a
-> n (Either Failure a, g, SeqTrace)
runConcurrency sched memtype g ma = do
ref <- newRef Nothing
let c = runCont ma (AStop . writeRef ref . Just . Right)
let threads = launch' Unmasked initialThread (const c) M.empty
let ctx = Context { cSchedState = g, cIdSource = initialIdSource, cThreads = threads, cWriteBuf = emptyBuffer, cCaps = 2 }
(finalCtx, trace) <- runThreads sched memtype ref ctx
out <- readRef ref
pure (fromJust out, cSchedState finalCtx, trace)
-- | The context a collection of threads are running in.
data Context n r g = Context
{ cSchedState :: g
, cIdSource :: IdSource
, cThreads :: Threads n r
, cWriteBuf :: WriteBuffer r
, cCaps :: Int
}
-- | Run a collection of threads, until there are no threads left.
--
-- Note: this returns the trace in reverse order, because it's more
-- efficient to prepend to a list than append. As this function isn't
-- exposed to users of the library, this is just an internal gotcha to
-- watch out for.
runThreads :: MonadRef r n => (forall x. s x -> IdSource -> n (Result x, IdSource, TTrace))
-> Scheduler g -> MemType -> g -> Threads n r s -> IdSource -> r (Maybe (Either Failure a)) -> n (g, Trace)
runThreads runstm sched memtype origg origthreads idsrc ref = go idsrc [] Nothing origg origthreads emptyBuffer 2 where
go idSource sofar prior g threads wb caps
| isTerminated = stop g
| isDeadlocked = die g Deadlock
| isSTMLocked = die g STMDeadlock
| isAborted = die g' Abort
| isNonexistant = die g' InternalError
| isBlocked = die g' InternalError
runThreads :: MonadRef r n
=> Scheduler g -> MemType -> r (Maybe (Either Failure a)) -> Context n r g -> n (Context n r g, SeqTrace)
runThreads sched memtype ref = go Seq.empty [] Nothing where
-- sofar is the 'SeqTrace', sofarSched is the @[(Decision,
-- ThreadAction)]@ trace the scheduler needs.
go sofar sofarSched prior ctx
| isTerminated = stop ctx
| isDeadlocked = die Deadlock ctx
| isSTMLocked = die STMDeadlock ctx
| isAborted = die Abort $ ctx { cSchedState = g' }
| isNonexistant = die InternalError $ ctx { cSchedState = g' }
| isBlocked = die InternalError $ ctx { cSchedState = g' }
| otherwise = do
stepped <- stepThread runstm memtype (_continuation $ fromJust thread) idSource chosen threads wb caps
stepped <- stepThread sched memtype chosen (_continuation $ fromJust thread) $ ctx { cSchedState = g' }
case stepped of
Right (threads', idSource', act, wb', caps') -> loop threads' idSource' act wb' caps'
Right (ctx', actOrTrc) -> loop actOrTrc ctx'
Left UncaughtException
| chosen == initialThread -> die g' UncaughtException
| otherwise -> loop (kill chosen threads) idSource Killed wb caps
Left failure -> die g' failure
| chosen == initialThread -> die UncaughtException $ ctx { cSchedState = g' }
| otherwise -> loop (Right Killed) $ ctx { cThreads = kill chosen threadsc, cSchedState = g' }
Left failure -> die failure $ ctx { cSchedState = g' }
where
(choice, g') = sched (map (\(d,_,a) -> (d,a)) $ reverse sofar) ((\p (_,_,a) -> (p,a)) <$> prior <*> listToMaybe sofar) (fromList $ map (\(t,l:|_) -> (t,l)) runnable') g
(choice, g') = sched sofarSched prior (fromList $ map (\(t,l:|_) -> (t,l)) runnable') (cSchedState ctx)
chosen = fromJust choice
runnable' = [(t, nextActions t) | t <- sort $ M.keys runnable]
runnable = M.filter (isNothing . _blocking) threadsc
thread = M.lookup chosen threadsc
threadsc = addCommitThreads wb threads
threadsc = addCommitThreads (cWriteBuf ctx) threads
threads = cThreads ctx
isAborted = isNothing choice
isBlocked = isJust . _blocking $ fromJust thread
isNonexistant = isNothing thread
@ -87,299 +120,262 @@ runThreads runstm sched memtype origg origthreads idsrc ref = go idsrc [] Nothin
_ -> thrd
decision
| Just chosen == prior = Continue
| prior `notElem` map (Just . fst) runnable' = Start chosen
| Just chosen == (fst <$> prior) = Continue
| (fst <$> prior) `notElem` map (Just . fst) runnable' = Start chosen
| otherwise = SwitchTo chosen
nextActions t = lookahead . _continuation . fromJust $ M.lookup t threadsc
stop outg = pure (outg, sofar)
die outg reason = writeRef ref (Just $ Left reason) >> stop outg
stop finalCtx = pure (finalCtx, sofar)
die reason finalCtx = writeRef ref (Just $ Left reason) >> stop finalCtx
loop threads' idSource' act wb' =
let sofar' = ((decision, runnable', act) : sofar)
threads'' = if (interruptible <$> M.lookup chosen threads') /= Just False then unblockWaitingOn chosen threads' else threads'
in go idSource' sofar' (Just chosen) g' (delCommitThreads threads'') wb'
loop trcOrAct ctx' =
let (act, trc) = case trcOrAct of
Left (a, as) -> (a, (decision, runnable', a) <| as)
Right a -> (a, Seq.singleton (decision, runnable', a))
threads' = if (interruptible <$> M.lookup chosen (cThreads ctx')) /= Just False
then unblockWaitingOn chosen (cThreads ctx')
else cThreads ctx'
sofar' = sofar <> trc
sofarSched' = sofarSched <> map (\(d,_,a) -> (d,a)) (F.toList trc)
prior' = Just (chosen, act)
in go sofar' sofarSched' prior' $ ctx' { cThreads = delCommitThreads threads' }
--------------------------------------------------------------------------------
-- * Single-step execution
-- | Run a single thread one step, by dispatching on the type of
-- 'Action'.
stepThread :: forall n r s. MonadRef r n
=> (forall x. s x -> IdSource -> n (Result x, IdSource, TTrace))
-- ^ Run a 'MonadSTM' transaction atomically.
stepThread :: forall n r g. MonadRef r n
=> Scheduler g
-- ^ The scheduler.
-> MemType
-- ^ The memory model
-> Action n r s
-- ^ Action to step
-> IdSource
-- ^ Source of fresh IDs
-- ^ The memory model to use.
-> ThreadId
-- ^ ID of the current thread
-> Threads n r s
-- ^ Current state of threads
-> WriteBuffer r
-- ^ @CRef@ write buffer
-> Int
-- ^ The number of capabilities
-> n (Either Failure (Threads n r s, IdSource, ThreadAction, WriteBuffer r, Int))
stepThread runstm memtype action idSource tid threads wb caps = case action of
AFork n a b -> stepFork n a b
AMyTId c -> stepMyTId c
AGetNumCapabilities c -> stepGetNumCapabilities c
ASetNumCapabilities i c -> stepSetNumCapabilities i c
AYield c -> stepYield c
ANewVar n c -> stepNewVar n c
APutVar var a c -> stepPutVar var a c
ATryPutVar var a c -> stepTryPutVar var a c
AReadVar var c -> stepReadVar var c
ATakeVar var c -> stepTakeVar var c
ATryTakeVar var c -> stepTryTakeVar var c
ANewRef n a c -> stepNewRef n a c
AReadRef ref c -> stepReadRef ref c
AReadRefCas ref c -> stepReadRefCas ref c
AModRef ref f c -> stepModRef ref f c
AModRefCas ref f c -> stepModRefCas ref f c
AWriteRef ref a c -> stepWriteRef ref a c
ACasRef ref tick a c -> stepCasRef ref tick a c
ACommit t c -> stepCommit t c
AAtom stm c -> stepAtom stm c
ALift na -> stepLift na
AThrow e -> stepThrow e
AThrowTo t e c -> stepThrowTo t e c
ACatching h ma c -> stepCatching h ma c
APopCatching a -> stepPopCatching a
AMasking m ma c -> stepMasking m ma c
AResetMask b1 b2 m c -> stepResetMask b1 b2 m c
AReturn c -> stepReturn c
AMessage m c -> stepMessage m c
AStop na -> stepStop na
-> Action n r
-- ^ Action to step
-> Context n r g
-- ^ The execution context.
-> n (Either Failure (Context n r g, Either (ThreadAction, SeqTrace) ThreadAction))
stepThread sched memtype tid action ctx = case action of
-- start a new thread, assigning it the next 'ThreadId'
AFork n a b -> pure . Right $
let threads' = launch tid newtid a (cThreads ctx)
(idSource', newtid) = nextTId n (cIdSource ctx)
in (ctx { cThreads = goto (b newtid) tid threads', cIdSource = idSource' }, Right (Fork newtid))
where
-- | Start a new thread, assigning it the next 'ThreadId'
--
-- Explicit type signature needed for GHC 8. Looks like the
-- impredicative polymorphism checks got stronger.
stepFork :: String
-> ((forall b. M n r s b -> M n r s b) -> Action n r s)
-> (ThreadId -> Action n r s)
-> n (Either Failure (Threads n r s, IdSource, ThreadAction, WriteBuffer r, Int))
stepFork n a b = return $ Right (goto (b newtid) tid threads', idSource', Fork newtid, wb, caps) where
threads' = launch tid newtid a threads
(idSource', newtid) = nextTId n idSource
-- get the 'ThreadId' of the current thread
AMyTId c -> simple (goto (c tid) tid (cThreads ctx)) MyThreadId
-- | Get the 'ThreadId' of the current thread
stepMyTId c = simple (goto (c tid) tid threads) MyThreadId
-- get the number of capabilities
AGetNumCapabilities c -> simple (goto (c (cCaps ctx)) tid (cThreads ctx)) $ GetNumCapabilities (cCaps ctx)
-- | Get the number of capabilities
stepGetNumCapabilities c = simple (goto (c caps) tid threads) $ GetNumCapabilities caps
-- set the number of capabilities
ASetNumCapabilities i c -> pure . Right $
(ctx { cThreads = goto c tid (cThreads ctx), cCaps = i }, Right (SetNumCapabilities i))
-- | Set the number of capabilities
stepSetNumCapabilities i c = return $ Right (goto c tid threads, idSource, SetNumCapabilities i, wb, i)
-- yield the current thread
AYield c -> simple (goto c tid (cThreads ctx)) Yield
-- | Yield the current thread
stepYield c = simple (goto c tid threads) Yield
-- create a new @MVar@, using the next 'MVarId'.
ANewVar n c -> do
let (idSource', newmvid) = nextMVId n (cIdSource ctx)
ref <- newRef Nothing
let mvar = MVar newmvid ref
pure $ Right (ctx { cThreads = goto (c mvar) tid (cThreads ctx), cIdSource = idSource' }, Right (NewVar newmvid))
-- | Put a value into a @MVar@, blocking the thread until it's
-- empty.
stepPutVar cvar@(MVar cvid _) a c = synchronised $ do
(success, threads', woken) <- putIntoMVar cvar a c tid threads
-- put a value into a @MVar@, blocking the thread until it's empty.
APutVar cvar@(MVar cvid _) a c -> synchronised $ do
(success, threads', woken) <- putIntoMVar cvar a c tid (cThreads ctx)
simple threads' $ if success then PutVar cvid woken else BlockedPutVar cvid
-- | Try to put a value into a @MVar@, without blocking.
stepTryPutVar cvar@(MVar cvid _) a c = synchronised $ do
(success, threads', woken) <- tryPutIntoMVar cvar a c tid threads
-- try to put a value into a @MVar@, without blocking.
ATryPutVar cvar@(MVar cvid _) a c -> synchronised $ do
(success, threads', woken) <- tryPutIntoMVar cvar a c tid (cThreads ctx)
simple threads' $ TryPutVar cvid success woken
-- | Get the value from a @MVar@, without emptying, blocking the
-- get the value from a @MVar@, without emptying, blocking the
-- thread until it's full.
stepReadVar cvar@(MVar cvid _) c = synchronised $ do
(success, threads', _) <- readFromMVar cvar c tid threads
AReadVar cvar@(MVar cvid _) c -> synchronised $ do
(success, threads', _) <- readFromMVar cvar c tid (cThreads ctx)
simple threads' $ if success then ReadVar cvid else BlockedReadVar cvid
-- | Take the value from a @MVar@, blocking the thread until it's
-- take the value from a @MVar@, blocking the thread until it's
-- full.
stepTakeVar cvar@(MVar cvid _) c = synchronised $ do
(success, threads', woken) <- takeFromMVar cvar c tid threads
ATakeVar cvar@(MVar cvid _) c -> synchronised $ do
(success, threads', woken) <- takeFromMVar cvar c tid (cThreads ctx)
simple threads' $ if success then TakeVar cvid woken else BlockedTakeVar cvid
-- | Try to take the value from a @MVar@, without blocking.
stepTryTakeVar cvar@(MVar cvid _) c = synchronised $ do
(success, threads', woken) <- tryTakeFromMVar cvar c tid threads
-- try to take the value from a @MVar@, without blocking.
ATryTakeVar cvar@(MVar cvid _) c -> synchronised $ do
(success, threads', woken) <- tryTakeFromMVar cvar c tid (cThreads ctx)
simple threads' $ TryTakeVar cvid success woken
-- | Read from a @CRef@.
stepReadRef cref@(CRef crid _) c = do
-- create a new @CRef@, using the next 'CRefId'.
ANewRef n a c -> do
let (idSource', newcrid) = nextCRId n (cIdSource ctx)
ref <- newRef (M.empty, 0, a)
let cref = CRef newcrid ref
pure $ Right (ctx { cThreads = goto (c cref) tid (cThreads ctx), cIdSource = idSource' }, Right (NewRef newcrid))
-- read from a @CRef@.
AReadRef cref@(CRef crid _) c -> do
val <- readCRef cref tid
simple (goto (c val) tid threads) $ ReadRef crid
simple (goto (c val) tid (cThreads ctx)) $ ReadRef crid
-- | Read from a @CRef@ for future compare-and-swap operations.
stepReadRefCas cref@(CRef crid _) c = do
-- read from a @CRef@ for future compare-and-swap operations.
AReadRefCas cref@(CRef crid _) c -> do
tick <- readForTicket cref tid
simple (goto (c tick) tid threads) $ ReadRefCas crid
simple (goto (c tick) tid (cThreads ctx)) $ ReadRefCas crid
-- | Modify a @CRef@.
stepModRef cref@(CRef crid _) f c = synchronised $ do
-- modify a @CRef@.
AModRef cref@(CRef crid _) f c -> synchronised $ do
(new, val) <- f <$> readCRef cref tid
writeImmediate cref new
simple (goto (c val) tid threads) $ ModRef crid
simple (goto (c val) tid (cThreads ctx)) $ ModRef crid
-- | Modify a @CRef@ using a compare-and-swap.
stepModRefCas cref@(CRef crid _) f c = synchronised $ do
-- modify a @CRef@ using a compare-and-swap.
AModRefCas cref@(CRef crid _) f c -> synchronised $ do
tick@(Ticket _ _ old) <- readForTicket cref tid
let (new, val) = f old
void $ casCRef cref tid tick new
simple (goto (c val) tid threads) $ ModRefCas crid
simple (goto (c val) tid (cThreads ctx)) $ ModRefCas crid
-- | Write to a @CRef@ without synchronising
stepWriteRef cref@(CRef crid _) a c = case memtype of
-- Write immediately.
-- write to a @CRef@ without synchronising.
AWriteRef cref@(CRef crid _) a c -> case memtype of
-- write immediately.
SequentialConsistency -> do
writeImmediate cref a
simple (goto c tid threads) $ WriteRef crid
-- Add to buffer using thread id.
simple (goto c tid (cThreads ctx)) $ WriteRef crid
-- add to buffer using thread id.
TotalStoreOrder -> do
wb' <- bufferWrite wb (tid, Nothing) cref a
return $ Right (goto c tid threads, idSource, WriteRef crid, wb', caps)
-- Add to buffer using both thread id and cref id
wb' <- bufferWrite (cWriteBuf ctx) (tid, Nothing) cref a
pure $ Right (ctx { cThreads = goto c tid (cThreads ctx), cWriteBuf = wb' }, Right (WriteRef crid))
-- add to buffer using both thread id and cref id
PartialStoreOrder -> do
wb' <- bufferWrite wb (tid, Just crid) cref a
return $ Right (goto c tid threads, idSource, WriteRef crid, wb', caps)
wb' <- bufferWrite (cWriteBuf ctx) (tid, Just crid) cref a
pure $ Right (ctx { cThreads = goto c tid (cThreads ctx), cWriteBuf = wb' }, Right (WriteRef crid))
-- | Perform a compare-and-swap on a @CRef@.
stepCasRef cref@(CRef crid _) tick a c = synchronised $ do
-- perform a compare-and-swap on a @CRef@.
ACasRef cref@(CRef crid _) tick a c -> synchronised $ do
(suc, tick') <- casCRef cref tid tick a
simple (goto (c (suc, tick')) tid threads) $ CasRef crid suc
simple (goto (c (suc, tick')) tid (cThreads ctx)) $ CasRef crid suc
-- | Commit a @CRef@ write
stepCommit t c = do
-- commit a @CRef@ write
ACommit t c -> do
wb' <- case memtype of
-- Shouldn't ever get here
-- shouldn't ever get here
SequentialConsistency ->
error "Attempting to commit under SequentialConsistency"
-- commit using the thread id.
TotalStoreOrder -> commitWrite (cWriteBuf ctx) (t, Nothing)
-- commit using the cref id.
PartialStoreOrder -> commitWrite (cWriteBuf ctx) (t, Just c)
pure $ Right (ctx { cWriteBuf = wb' }, Right (CommitRef t c))
-- Commit using the thread id.
TotalStoreOrder -> commitWrite wb (t, Nothing)
-- Commit using the cref id.
PartialStoreOrder -> commitWrite wb (t, Just c)
return $ Right (threads, idSource, CommitRef t c, wb', caps)
-- | Run a STM transaction atomically.
stepAtom stm c = synchronised $ do
(res, idSource', trace) <- runstm stm idSource
-- run a STM transaction atomically.
AAtom stm c -> synchronised $ do
(res, idSource', trace) <- runTransaction stm (cIdSource ctx)
case res of
Success _ written val ->
let (threads', woken) = wake (OnTVar written) threads
in return $ Right (goto (c val) tid threads', idSource', STM trace woken, wb, caps)
let (threads', woken) = wake (OnTVar written) (cThreads ctx)
in pure $ Right (ctx { cThreads = goto (c val) tid threads', cIdSource = idSource' }, Right (STM trace woken))
Retry touched ->
let threads' = block (OnTVar touched) tid threads
in return $ Right (threads', idSource', BlockedSTM trace, wb, caps)
let threads' = block (OnTVar touched) tid (cThreads ctx)
in pure $ Right (ctx { cThreads = threads', cIdSource = idSource'}, Right (BlockedSTM trace))
Exception e -> do
res' <- stepThrow e
return $ case res' of
Right (threads', _, _, _, _) -> Right (threads', idSource', Throw, wb, caps)
pure $ case res' of
Right (ctx', _) -> Right (ctx' { cIdSource = idSource' }, Right Throw)
Left err -> Left err
-- | Run a subcomputation in an exception-catching context.
stepCatching h ma c = simple threads' Catching where
a = runCont ma (APopCatching . c)
e exc = runCont (h exc) (APopCatching . c)
-- lift an action from the underlying monad into the @Conc@
-- computation.
ALift na -> do
a <- na
simple (goto a tid (cThreads ctx)) LiftIO
threads' = goto a tid (catching e tid threads)
-- | Pop the top exception handler from the thread's stack.
stepPopCatching a = simple threads' PopCatching where
threads' = goto a tid (uncatching tid threads)
-- | Throw an exception, and propagate it to the appropriate
-- throw an exception, and propagate it to the appropriate
-- handler.
stepThrow e =
case propagate (toException e) tid threads of
Just threads' -> simple threads' Throw
Nothing -> return $ Left UncaughtException
AThrow e -> stepThrow e
-- | Throw an exception to the target thread, and propagate it to
-- throw an exception to the target thread, and propagate it to
-- the appropriate handler.
stepThrowTo t e c = synchronised $
let threads' = goto c tid threads
blocked = block (OnMask t) tid threads
in case M.lookup t threads of
AThrowTo t e c -> synchronised $
let threads' = goto c tid (cThreads ctx)
blocked = block (OnMask t) tid (cThreads ctx)
in case M.lookup t (cThreads ctx) of
Just thread
| interruptible thread -> case propagate (toException e) t threads' of
Just threads'' -> simple threads'' $ ThrowTo t
Nothing
| t == initialThread -> return $ Left UncaughtException
| t == initialThread -> pure $ Left UncaughtException
| otherwise -> simple (kill t threads') $ ThrowTo t
| otherwise -> simple blocked $ BlockedThrowTo t
Nothing -> simple threads' $ ThrowTo t
-- | Execute a subcomputation with a new masking state, and give
-- it a function to run a computation with the current masking
-- state.
--
-- Explicit type sig necessary for checking in the prescence of
-- 'umask', sadly.
stepMasking :: MaskingState
-> ((forall b. M n r s b -> M n r s b) -> M n r s a)
-> (a -> Action n r s)
-> n (Either Failure (Threads n r s, IdSource, ThreadAction, WriteBuffer r, Int))
stepMasking m ma c = simple threads' $ SetMasking False m where
a = runCont (ma umask) (AResetMask False False m' . c)
-- run a subcomputation in an exception-catching context.
ACatching h ma c ->
let a = runCont ma (APopCatching . c)
e exc = runCont (h exc) (APopCatching . c)
threads' = goto a tid (catching e tid (cThreads ctx))
in simple threads' Catching
m' = _masking . fromJust $ M.lookup tid threads
umask mb = resetMask True m' >> mb >>= \b -> resetMask False m >> return b
-- pop the top exception handler from the thread's stack.
APopCatching a ->
let threads' = goto a tid (uncatching tid (cThreads ctx))
in simple threads' PopCatching
-- execute a subcomputation with a new masking state, and give it
-- a function to run a computation with the current masking state.
AMasking m ma c ->
let a = runCont (ma umask) (AResetMask False False m' . c)
m' = _masking . fromJust $ M.lookup tid (cThreads ctx)
umask mb = resetMask True m' >> mb >>= \b -> resetMask False m >> pure b
resetMask typ ms = cont $ \k -> AResetMask typ True ms $ k ()
threads' = goto a tid (mask m tid (cThreads ctx))
in simple threads' $ SetMasking False m
threads' = goto a tid (mask m tid threads)
-- | Reset the masking thread of the state.
stepResetMask b1 b2 m c = simple threads' act where
act = (if b1 then SetMasking else ResetMasking) b2 m
threads' = goto c tid (mask m tid threads)
-- reset the masking thread of the state.
AResetMask b1 b2 m c ->
let act = (if b1 then SetMasking else ResetMasking) b2 m
threads' = goto c tid (mask m tid (cThreads ctx))
in simple threads' act
-- | Create a new @MVar@, using the next 'MVarId'.
stepNewVar n c = do
let (idSource', newmvid) = nextMVId n idSource
ref <- newRef Nothing
let mvar = MVar newmvid ref
return $ Right (goto (c mvar) tid threads, idSource', NewVar newmvid, wb, caps)
-- execute a 'return' or 'pure'.
AReturn c -> simple (goto c tid (cThreads ctx)) Return
-- | Create a new @CRef@, using the next 'CRefId'.
stepNewRef n a c = do
let (idSource', newcrid) = nextCRId n idSource
ref <- newRef (M.empty, 0, a)
let cref = CRef newcrid ref
return $ Right (goto (c cref) tid threads, idSource', NewRef newcrid, wb, caps)
-- add a message to the trace.
AMessage m c -> simple (goto c tid (cThreads ctx)) (Message m)
-- | Lift an action from the underlying monad into the @Conc@
-- computation.
stepLift na = do
a <- na
simple (goto a tid threads) LiftIO
-- kill the current thread.
AStop na -> na >> simple (kill tid (cThreads ctx)) Stop
-- | Execute a 'return' or 'pure'.
stepReturn c = simple (goto c tid threads) Return
-- run a subconcurrent computation.
ASub ma c
| M.size (cThreads ctx) > 1 -> pure (Left IllegalSubconcurrency)
| otherwise -> do
(res, g', trace) <- runConcurrency sched memtype (cSchedState ctx) ma
pure $ Right (ctx { cThreads = goto (c res) tid (cThreads ctx), cSchedState = g' }, Left (Subconcurrency, trace))
where
-- | Add a message to the trace.
stepMessage m c = simple (goto c tid threads) (Message m)
-- this is not inline in the long @case@ above as it's needed by
-- @AAtom@, @AThrow@, and @AThrowTo@.
stepThrow e =
case propagate (toException e) tid (cThreads ctx) of
Just threads' -> simple threads' Throw
Nothing -> pure $ Left UncaughtException
-- | Kill the current thread.
stepStop na = na >> simple (kill tid threads) Stop
-- helper for actions which only change the threads.
simple threads' act = pure $ Right (ctx { cThreads = threads' }, Right act)
-- | Helper for actions which don't touch the 'IdSource' or
-- 'WriteBuffer'
simple threads' act = return $ Right (threads', idSource, act, wb, caps)
-- | Helper for actions impose a write barrier.
-- helper for actions impose a write barrier.
synchronised ma = do
writeBarrier wb
writeBarrier (cWriteBuf ctx)
res <- ma
return $ case res of
Right (threads', idSource', act', _, caps') -> Right (threads', idSource', act', emptyBuffer, caps')
Right (ctx', act) -> Right (ctx' { cWriteBuf = emptyBuffer }, act)
_ -> res

View File

@ -18,6 +18,7 @@ import Data.Dynamic (Dynamic)
import Data.Map.Strict (Map)
import Data.List.NonEmpty (NonEmpty, fromList)
import Test.DejaFu.Common
import Test.DejaFu.STM (STMLike)
{-# ANN module ("HLint: ignore Use record patterns" :: String) #-}
@ -32,16 +33,16 @@ import Test.DejaFu.Common
-- current expression of threads and exception handlers very difficult
-- (perhaps even not possible without significant reworking), so I
-- abandoned the attempt.
newtype M n r s a = M { runM :: (a -> Action n r s) -> Action n r s }
newtype M n r a = M { runM :: (a -> Action n r) -> Action n r }
instance Functor (M n r s) where
instance Functor (M n r) where
fmap f m = M $ \ c -> runM m (c . f)
instance Applicative (M n r s) where
instance Applicative (M n r) where
pure x = M $ \c -> AReturn $ c x
f <*> v = M $ \c -> runM f (\g -> runM v (c . g))
instance Monad (M n r s) where
instance Monad (M n r) where
return = pure
m >>= k = M $ \c -> runM m (\x -> runM (k x) c)
@ -82,11 +83,11 @@ data Ticket a = Ticket
}
-- | Construct a continuation-passing operation from a function.
cont :: ((a -> Action n r s) -> Action n r s) -> M n r s a
cont :: ((a -> Action n r) -> Action n r) -> M n r a
cont = M
-- | Run a CPS computation with the given final computation.
runCont :: M n r s a -> (a -> Action n r s) -> Action n r s
runCont :: M n r a -> (a -> Action n r) -> Action n r
runCont = runM
--------------------------------------------------------------------------------
@ -96,49 +97,51 @@ runCont = runM
-- only occur as a result of an action, and they cover (most of) the
-- primitives of the concurrency. 'spawn' is absent as it is
-- implemented in terms of 'newEmptyMVar', 'fork', and 'putMVar'.
data Action n r s =
AFork String ((forall b. M n r s b -> M n r s b) -> Action n r s) (ThreadId -> Action n r s)
| AMyTId (ThreadId -> Action n r s)
data Action n r =
AFork String ((forall b. M n r b -> M n r b) -> Action n r) (ThreadId -> Action n r)
| AMyTId (ThreadId -> Action n r)
| AGetNumCapabilities (Int -> Action n r s)
| ASetNumCapabilities Int (Action n r s)
| AGetNumCapabilities (Int -> Action n r)
| ASetNumCapabilities Int (Action n r)
| forall a. ANewVar String (MVar r a -> Action n r s)
| forall a. APutVar (MVar r a) a (Action n r s)
| forall a. ATryPutVar (MVar r a) a (Bool -> Action n r s)
| forall a. AReadVar (MVar r a) (a -> Action n r s)
| forall a. ATakeVar (MVar r a) (a -> Action n r s)
| forall a. ATryTakeVar (MVar r a) (Maybe a -> Action n r s)
| forall a. ANewVar String (MVar r a -> Action n r)
| forall a. APutVar (MVar r a) a (Action n r)
| forall a. ATryPutVar (MVar r a) a (Bool -> Action n r)
| forall a. AReadVar (MVar r a) (a -> Action n r)
| forall a. ATakeVar (MVar r a) (a -> Action n r)
| forall a. ATryTakeVar (MVar r a) (Maybe a -> Action n r)
| forall a. ANewRef String a (CRef r a -> Action n r s)
| forall a. AReadRef (CRef r a) (a -> Action n r s)
| forall a. AReadRefCas (CRef r a) (Ticket a -> Action n r s)
| forall a b. AModRef (CRef r a) (a -> (a, b)) (b -> Action n r s)
| forall a b. AModRefCas (CRef r a) (a -> (a, b)) (b -> Action n r s)
| forall a. AWriteRef (CRef r a) a (Action n r s)
| forall a. ACasRef (CRef r a) (Ticket a) a ((Bool, Ticket a) -> Action n r s)
| forall a. ANewRef String a (CRef r a -> Action n r)
| forall a. AReadRef (CRef r a) (a -> Action n r)
| forall a. AReadRefCas (CRef r a) (Ticket a -> Action n r)
| forall a b. AModRef (CRef r a) (a -> (a, b)) (b -> Action n r)
| forall a b. AModRefCas (CRef r a) (a -> (a, b)) (b -> Action n r)
| forall a. AWriteRef (CRef r a) a (Action n r)
| forall a. ACasRef (CRef r a) (Ticket a) a ((Bool, Ticket a) -> Action n r)
| forall e. Exception e => AThrow e
| forall e. Exception e => AThrowTo ThreadId e (Action n r s)
| forall a e. Exception e => ACatching (e -> M n r s a) (M n r s a) (a -> Action n r s)
| APopCatching (Action n r s)
| forall a. AMasking MaskingState ((forall b. M n r s b -> M n r s b) -> M n r s a) (a -> Action n r s)
| AResetMask Bool Bool MaskingState (Action n r s)
| forall e. Exception e => AThrowTo ThreadId e (Action n r)
| forall a e. Exception e => ACatching (e -> M n r a) (M n r a) (a -> Action n r)
| APopCatching (Action n r)
| forall a. AMasking MaskingState ((forall b. M n r b -> M n r b) -> M n r a) (a -> Action n r)
| AResetMask Bool Bool MaskingState (Action n r)
| AMessage Dynamic (Action n r s)
| AMessage Dynamic (Action n r)
| forall a. AAtom (s a) (a -> Action n r s)
| ALift (n (Action n r s))
| AYield (Action n r s)
| AReturn (Action n r s)
| forall a. AAtom (STMLike n r a) (a -> Action n r)
| ALift (n (Action n r))
| AYield (Action n r)
| AReturn (Action n r)
| ACommit ThreadId CRefId
| AStop (n ())
| forall a. ASub (M n r a) (Either Failure a -> Action n r)
--------------------------------------------------------------------------------
-- * Scheduling & Traces
-- | Look as far ahead in the given continuation as possible.
lookahead :: Action n r s -> NonEmpty Lookahead
lookahead :: Action n r -> NonEmpty Lookahead
lookahead = fromList . lookahead' where
lookahead' (AFork _ _ _) = [WillFork]
lookahead' (AMyTId _) = [WillMyThreadId]
@ -170,3 +173,4 @@ lookahead = fromList . lookahead' where
lookahead' (AYield k) = WillYield : lookahead' k
lookahead' (AReturn k) = WillReturn : lookahead' k
lookahead' (AStop _) = [WillStop]
lookahead' (ASub _ _) = [WillSubconcurrency]

View File

@ -126,7 +126,7 @@ writeBarrier (WriteBuffer wb) = mapM_ flush $ M.elems wb where
flush = mapM_ $ \(BufferedWrite _ cref a) -> writeImmediate cref a
-- | Add phantom threads to the thread list to commit pending writes.
addCommitThreads :: WriteBuffer r -> Threads n r s -> Threads n r s
addCommitThreads :: WriteBuffer r -> Threads n r -> Threads n r
addCommitThreads (WriteBuffer wb) ts = ts <> M.fromList phantoms where
phantoms = [ (ThreadId Nothing $ negate tid, mkthread $ fromJust c)
| ((k, b), tid) <- zip (M.toList wb) [1..]
@ -136,41 +136,41 @@ addCommitThreads (WriteBuffer wb) ts = ts <> M.fromList phantoms where
go EmptyL = Nothing
-- | Remove phantom threads.
delCommitThreads :: Threads n r s -> Threads n r s
delCommitThreads :: Threads n r -> Threads n r
delCommitThreads = M.filterWithKey $ \k _ -> k >= initialThread
--------------------------------------------------------------------------------
-- * Manipulating @MVar@s
-- | Put into a @MVar@, blocking if full.
putIntoMVar :: MonadRef r n => MVar r a -> a -> Action n r s
-> ThreadId -> Threads n r s -> n (Bool, Threads n r s, [ThreadId])
putIntoMVar :: MonadRef r n => MVar r a -> a -> Action n r
-> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
putIntoMVar cvar a c = mutMVar True cvar a (const c)
-- | Try to put into a @MVar@, not blocking if full.
tryPutIntoMVar :: MonadRef r n => MVar r a -> a -> (Bool -> Action n r s)
-> ThreadId -> Threads n r s -> n (Bool, Threads n r s, [ThreadId])
tryPutIntoMVar :: MonadRef r n => MVar r a -> a -> (Bool -> Action n r)
-> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
tryPutIntoMVar = mutMVar False
-- | Read from a @MVar@, blocking if empty.
readFromMVar :: MonadRef r n => MVar r a -> (a -> Action n r s)
-> ThreadId -> Threads n r s -> n (Bool, Threads n r s, [ThreadId])
readFromMVar :: MonadRef r n => MVar r a -> (a -> Action n r)
-> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
readFromMVar cvar c = seeMVar False True cvar (c . fromJust)
-- | Take from a @MVar@, blocking if empty.
takeFromMVar :: MonadRef r n => MVar r a -> (a -> Action n r s)
-> ThreadId -> Threads n r s -> n (Bool, Threads n r s, [ThreadId])
takeFromMVar :: MonadRef r n => MVar r a -> (a -> Action n r)
-> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
takeFromMVar cvar c = seeMVar True True cvar (c . fromJust)
-- | Try to take from a @MVar@, not blocking if empty.
tryTakeFromMVar :: MonadRef r n => MVar r a -> (Maybe a -> Action n r s)
-> ThreadId -> Threads n r s -> n (Bool, Threads n r s, [ThreadId])
tryTakeFromMVar :: MonadRef r n => MVar r a -> (Maybe a -> Action n r)
-> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
tryTakeFromMVar = seeMVar True False
-- | Mutate a @MVar@, in either a blocking or nonblocking way.
mutMVar :: MonadRef r n
=> Bool -> MVar r a -> a -> (Bool -> Action n r s)
-> ThreadId -> Threads n r s -> n (Bool, Threads n r s, [ThreadId])
=> Bool -> MVar r a -> a -> (Bool -> Action n r)
-> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
mutMVar blocking (MVar cvid ref) a c threadid threads = do
val <- readRef ref
@ -191,8 +191,8 @@ mutMVar blocking (MVar cvid ref) a c threadid threads = do
-- | Read a @MVar@, in either a blocking or nonblocking
-- way.
seeMVar :: MonadRef r n
=> Bool -> Bool -> MVar r a -> (Maybe a -> Action n r s)
-> ThreadId -> Threads n r s -> n (Bool, Threads n r s, [ThreadId])
=> Bool -> Bool -> MVar r a -> (Maybe a -> Action n r)
-> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
seeMVar emptying blocking (MVar cvid ref) c threadid threads = do
val <- readRef ref

View File

@ -27,22 +27,22 @@ import qualified Data.Map.Strict as M
-- * Threads
-- | Threads are stored in a map index by 'ThreadId'.
type Threads n r s = Map ThreadId (Thread n r s)
type Threads n r = Map ThreadId (Thread n r)
-- | All the state of a thread.
data Thread n r s = Thread
{ _continuation :: Action n r s
data Thread n r = Thread
{ _continuation :: Action n r
-- ^ The next action to execute.
, _blocking :: Maybe BlockedOn
-- ^ The state of any blocks.
, _handlers :: [Handler n r s]
, _handlers :: [Handler n r]
-- ^ Stack of exception handlers
, _masking :: MaskingState
-- ^ The exception masking state.
}
-- | Construct a thread with just one action
mkthread :: Action n r s -> Thread n r s
mkthread :: Action n r -> Thread n r
mkthread c = Thread c Nothing [] Unmasked
--------------------------------------------------------------------------------
@ -53,7 +53,7 @@ mkthread c = Thread c Nothing [] Unmasked
data BlockedOn = OnMVarFull MVarId | OnMVarEmpty MVarId | OnTVar [TVarId] | OnMask ThreadId deriving Eq
-- | Determine if a thread is blocked in a certain way.
(~=) :: Thread n r s -> BlockedOn -> Bool
(~=) :: Thread n r -> BlockedOn -> Bool
thread ~= theblock = case (_blocking thread, theblock) of
(Just (OnMVarFull _), OnMVarFull _) -> True
(Just (OnMVarEmpty _), OnMVarEmpty _) -> True
@ -65,11 +65,11 @@ thread ~= theblock = case (_blocking thread, theblock) of
-- * Exceptions
-- | An exception handler.
data Handler n r s = forall e. Exception e => Handler (e -> Action n r s)
data Handler n r = forall e. Exception e => Handler (e -> Action n r)
-- | Propagate an exception upwards, finding the closest handler
-- which can deal with it.
propagate :: SomeException -> ThreadId -> Threads n r s -> Maybe (Threads n r s)
propagate :: SomeException -> ThreadId -> Threads n r -> Maybe (Threads n r)
propagate e tid threads = case M.lookup tid threads >>= go . _handlers of
Just (act, hs) -> Just $ except act hs tid threads
Nothing -> Nothing
@ -79,40 +79,40 @@ propagate e tid threads = case M.lookup tid threads >>= go . _handlers of
go (Handler h:hs) = maybe (go hs) (\act -> Just (act, hs)) $ h <$> fromException e
-- | Check if a thread can be interrupted by an exception.
interruptible :: Thread n r s -> Bool
interruptible :: Thread n r -> Bool
interruptible thread = _masking thread == Unmasked || (_masking thread == MaskedInterruptible && isJust (_blocking thread))
-- | Register a new exception handler.
catching :: Exception e => (e -> Action n r s) -> ThreadId -> Threads n r s -> Threads n r s
catching :: Exception e => (e -> Action n r) -> ThreadId -> Threads n r -> Threads n r
catching h = M.alter $ \(Just thread) -> Just $ thread { _handlers = Handler h : _handlers thread }
-- | Remove the most recent exception handler.
uncatching :: ThreadId -> Threads n r s -> Threads n r s
uncatching :: ThreadId -> Threads n r -> Threads n r
uncatching = M.alter $ \(Just thread) -> Just $ thread { _handlers = tail $ _handlers thread }
-- | Raise an exception in a thread.
except :: Action n r s -> [Handler n r s] -> ThreadId -> Threads n r s -> Threads n r s
except :: Action n r -> [Handler n r] -> ThreadId -> Threads n r -> Threads n r
except act hs = M.alter $ \(Just thread) -> Just $ thread { _continuation = act, _handlers = hs, _blocking = Nothing }
-- | Set the masking state of a thread.
mask :: MaskingState -> ThreadId -> Threads n r s -> Threads n r s
mask :: MaskingState -> ThreadId -> Threads n r -> Threads n r
mask ms = M.alter $ \(Just thread) -> Just $ thread { _masking = ms }
--------------------------------------------------------------------------------
-- * Manipulating threads
-- | Replace the @Action@ of a thread.
goto :: Action n r s -> ThreadId -> Threads n r s -> Threads n r s
goto :: Action n r -> ThreadId -> Threads n r -> Threads n r
goto a = M.alter $ \(Just thread) -> Just (thread { _continuation = a })
-- | Start a thread with the given ID, inheriting the masking state
-- from the parent thread. This ID must not already be in use!
launch :: ThreadId -> ThreadId -> ((forall b. M n r s b -> M n r s b) -> Action n r s) -> Threads n r s -> Threads n r s
launch :: ThreadId -> ThreadId -> ((forall b. M n r b -> M n r b) -> Action n r) -> Threads n r -> Threads n r
launch parent tid a threads = launch' ms tid a threads where
ms = fromMaybe Unmasked $ _masking <$> M.lookup parent threads
-- | Start a thread with the given ID and masking state. This must not already be in use!
launch' :: MaskingState -> ThreadId -> ((forall b. M n r s b -> M n r s b) -> Action n r s) -> Threads n r s -> Threads n r s
launch' :: MaskingState -> ThreadId -> ((forall b. M n r b -> M n r b) -> Action n r) -> Threads n r -> Threads n r
launch' ms tid a = M.insert tid thread where
thread = Thread { _continuation = a umask, _blocking = Nothing, _handlers = [], _masking = ms }
@ -120,11 +120,11 @@ launch' ms tid a = M.insert tid thread where
resetMask typ m = cont $ \k -> AResetMask typ True m $ k ()
-- | Kill a thread.
kill :: ThreadId -> Threads n r s -> Threads n r s
kill :: ThreadId -> Threads n r -> Threads n r
kill = M.delete
-- | Block a thread.
block :: BlockedOn -> ThreadId -> Threads n r s -> Threads n r s
block :: BlockedOn -> ThreadId -> Threads n r -> Threads n r
block blockedOn = M.alter doBlock where
doBlock (Just thread) = Just $ thread { _blocking = Just blockedOn }
doBlock _ = error "Invariant failure in 'block': thread does NOT exist!"
@ -132,7 +132,7 @@ block blockedOn = M.alter doBlock where
-- | Unblock all threads waiting on the appropriate block. For 'TVar'
-- blocks, this will wake all threads waiting on at least one of the
-- given 'TVar's.
wake :: BlockedOn -> Threads n r s -> (Threads n r s, [ThreadId])
wake :: BlockedOn -> Threads n r -> (Threads n r, [ThreadId])
wake blockedOn threads = (unblock <$> threads, M.keys $ M.filter isBlocked threads) where
unblock thread
| isBlocked thread = thread { _blocking = Nothing }

View File

@ -79,9 +79,8 @@ module Test.DejaFu.SCT
, sctLengthBound
) where
import Control.DeepSeq (NFData(..))
import Control.Monad.Ref (MonadRef)
import Data.List (nub)
import Data.List (foldl')
import qualified Data.Map.Strict as M
import Data.Maybe (isJust, fromJust)
import qualified Data.Set as S
@ -140,17 +139,16 @@ cBound (Bounds pb fb lb) =
--
-- If no bounds are enabled, just backtrack to the given point.
cBacktrack :: Bounds -> BacktrackFunc
cBacktrack (Bounds Nothing Nothing Nothing) bs i t = backtrackAt (const False) False bs i t
cBacktrack (Bounds pb fb lb) bs i t = lBack . fBack $ pBack bs where
pBack backs = if isJust pb then pBacktrack backs i t else backs
fBack backs = if isJust fb then fBacktrack backs i t else backs
lBack backs = if isJust lb then lBacktrack backs i t else backs
cBacktrack (Bounds (Just _) _ _) = pBacktrack
cBacktrack (Bounds _ (Just _) _) = fBacktrack
cBacktrack (Bounds _ _ (Just _)) = lBacktrack
cBacktrack _ = backtrackAt (\_ _ -> False)
-------------------------------------------------------------------------------
-- Pre-emption bounding
newtype PreemptionBound = PreemptionBound Int
deriving (NFData, Enum, Eq, Ord, Num, Real, Integral, Read, Show)
deriving (Enum, Eq, Ord, Num, Real, Integral, Read, Show)
-- | A sensible default preemption bound: 2.
--
@ -181,31 +179,26 @@ pBound (PreemptionBound pb) ts dl = preEmpCount ts dl <= pb
-- the same state being reached multiple times, but is needed because
-- of the artificial dependency imposed by the bound.
pBacktrack :: BacktrackFunc
pBacktrack bs i tid =
maybe id (\j' b -> backtrack True b j' tid) j $ backtrack False bs i tid
pBacktrack bs = backtrackAt (\_ _ -> False) bs . concatMap addConservative where
addConservative o@(i, _, tid) = o : case conservative i of
Just j -> [(j, True, tid)]
Nothing -> []
where
-- Index of the conservative point
j = goJ . reverse . pairs $ zip [0..i-1] bs where
goJ (((_,b1), (j',b2)):rest)
-- index of conservative point
conservative i = go (reverse (take (i-1) bs)) (i-1) where
go _ (-1) = Nothing
go (b1:rest@(b2:_)) j
| bcktThreadid b1 /= bcktThreadid b2
&& not (isCommitRef . snd $ bcktDecision b1)
&& not (isCommitRef . snd $ bcktDecision b2) = Just j'
| otherwise = goJ rest
goJ [] = Nothing
-- List of adjacent pairs
{-# INLINE pairs #-}
pairs = zip <*> tail
-- Add a backtracking point.
backtrack = backtrackAt $ const False
&& not (isCommitRef $ bcktAction b1)
&& not (isCommitRef $ bcktAction b2) = Just j
| otherwise = go rest (j-1)
go _ _ = Nothing
-------------------------------------------------------------------------------
-- Fair bounding
newtype FairBound = FairBound Int
deriving (NFData, Enum, Eq, Ord, Num, Real, Integral, Read, Show)
deriving (Enum, Eq, Ord, Num, Real, Integral, Read, Show)
-- | A sensible default fair bound: 5.
--
@ -233,15 +226,15 @@ fBound (FairBound fb) ts (_, l) = maxYieldCountDiff ts l <= fb
-- | Add a backtrack point. If the thread isn't runnable, or performs
-- a release operation, add all runnable threads.
fBacktrack :: BacktrackFunc
fBacktrack bs i t = backtrackAt check False bs i t where
fBacktrack = backtrackAt check where
-- True if a release operation is performed.
check b = Just True == (willRelease <$> M.lookup t (bcktRunnable b))
check t b = Just True == (willRelease <$> M.lookup t (bcktRunnable b))
-------------------------------------------------------------------------------
-- Length bounding
newtype LengthBound = LengthBound Int
deriving (NFData, Enum, Eq, Ord, Num, Real, Integral, Read, Show)
deriving (Enum, Eq, Ord, Num, Real, Integral, Read, Show)
-- | A sensible default length bound: 250.
--
@ -269,7 +262,7 @@ lBound (LengthBound lb) ts _ = length ts < lb
-- | Add a backtrack point. If the thread isn't runnable, add all
-- runnable threads.
lBacktrack :: BacktrackFunc
lBacktrack = backtrackAt (const False) False
lBacktrack = backtrackAt (\_ _ -> False)
-------------------------------------------------------------------------------
-- DPOR
@ -313,7 +306,7 @@ sctBounded memtype bf backtrack conc = go initialState where
if schedIgnore s
then go newDPOR
else ((res, trace):) <$> go (pruneCommits $ addBacktracks bpoints newDPOR)
else ((res, trace):) <$> go (addBacktracks bpoints newDPOR)
Nothing -> pure []
@ -332,32 +325,11 @@ sctBounded memtype bf backtrack conc = go initialState where
-- Incorporate the new backtracking steps into the DPOR tree.
addBacktracks = incorporateBacktrackSteps bf
-------------------------------------------------------------------------------
-- Post-processing
-- | Remove commits from the todo sets where every other action will
-- result in a write barrier (and so a commit) occurring.
--
-- To get the benefit from this, do not execute commit actions from
-- the todo set until there are no other choises.
pruneCommits :: DPOR -> DPOR
pruneCommits bpor
| not onlycommits || not alldonesync = go bpor
| otherwise = go bpor { dporTodo = M.empty }
where
go b = b { dporDone = pruneCommits <$> dporDone bpor }
onlycommits = all (<initialThread) . M.keys $ dporTodo bpor
alldonesync = all barrier . M.elems $ dporDone bpor
barrier = isBarrier . simplifyAction . fromJust . dporAction
-------------------------------------------------------------------------------
-- Dependency function
-- | Check if an action is dependent on another.
dependent :: MemType -> DepState -> (ThreadId, ThreadAction) -> (ThreadId, ThreadAction) -> Bool
dependent :: MemType -> DepState -> ThreadId -> ThreadAction -> ThreadId -> ThreadAction -> Bool
-- This is basically the same as 'dependent'', but can make use of the
-- additional information in a 'ThreadAction' to make different
-- decisions in a few cases:
@ -381,14 +353,14 @@ dependent :: MemType -> DepState -> (ThreadId, ThreadAction) -> (ThreadId, Threa
-- - Dependency of STM transactions can be /greatly/ improved here,
-- as the 'Lookahead' does not know which @TVar@s will be touched,
-- and so has to assume all transactions are dependent.
dependent _ _ (_, SetNumCapabilities a) (_, GetNumCapabilities b) = a /= b
dependent _ ds (_, ThrowTo t) (t2, a) = t == t2 && canInterrupt ds t2 a
dependent memtype ds (t1, a1) (t2, a2) = case rewind a2 of
dependent _ _ _ (SetNumCapabilities a) _ (GetNumCapabilities b) = a /= b
dependent _ ds _ (ThrowTo t) t2 a = t == t2 && canInterrupt ds t2 a
dependent memtype ds t1 a1 t2 a2 = case rewind a2 of
Just l2
| isSTM a1 && isSTM a2
-> not . S.null $ tvarsOf a1 `S.intersection` tvarsOf a2
| not (isBlock a1 && isBarrier (simplifyLookahead l2)) ->
dependent' memtype ds (t1, a1) (t2, l2)
dependent' memtype ds t1 a1 t2 l2
_ -> dependentActions memtype ds (simplifyAction a1) (simplifyAction a2)
where
@ -400,8 +372,8 @@ dependent memtype ds (t1, a1) (t2, a2) = case rewind a2 of
--
-- Termination of the initial thread is handled specially in the DPOR
-- implementation.
dependent' :: MemType -> DepState -> (ThreadId, ThreadAction) -> (ThreadId, Lookahead) -> Bool
dependent' memtype ds (t1, a1) (t2, l2) = case (a1, l2) of
dependent' :: MemType -> DepState -> ThreadId -> ThreadAction -> ThreadId -> Lookahead -> Bool
dependent' memtype ds t1 a1 t2 l2 = case (a1, l2) of
-- Worst-case assumption: all IO is dependent.
(LiftIO, WillLiftIO) -> True
@ -496,6 +468,7 @@ yieldCount tid ts l = go initialThread ts where
| t == tid && willYield l = 1
| otherwise = 0
{-# INLINE go' #-}
go' t t' act rest
| t == tid && didYield act = 1 + go t' rest
| otherwise = go t' rest
@ -505,10 +478,14 @@ yieldCount tid ts l = go initialThread ts where
maxYieldCountDiff :: [(Decision, ThreadAction)]
-> Lookahead
-> Int
maxYieldCountDiff ts l = maximum yieldCountDiffs where
yieldsBy tid = yieldCount tid ts l
yieldCounts = [yieldsBy tid | tid <- nub $ allTids ts]
yieldCountDiffs = [y1 - y2 | y1 <- yieldCounts, y2 <- yieldCounts]
maxYieldCountDiff ts l = go 0 yieldCounts where
go m (yc:ycs) =
let m' = m `max` foldl' (go' yc) 0 ycs
in go m' ycs
go m [] = m
go' yc0 m yc = m `max` abs (yc0 - yc)
yieldCounts = [yieldCount t ts l | t <- allTids ts]
-- All the threads created during the lifetime of the system.
allTids ((_, act):rest) =

View File

@ -11,14 +11,14 @@
-- interface of this library.
module Test.DejaFu.SCT.Internal where
import Control.DeepSeq (NFData(..), force)
import Control.Exception (MaskingState(..))
import Data.Char (ord)
import Data.List (foldl', intercalate, partition, sortBy)
import Data.Function (on)
import qualified Data.Foldable as F
import Data.List (intercalate, nubBy, partition, sortOn)
import Data.List.NonEmpty (NonEmpty(..), toList)
import Data.Ord (Down(..), comparing)
import Data.Map.Strict (Map)
import Data.Maybe (fromJust, isJust, isNothing, mapMaybe)
import Data.Maybe (catMaybes, fromJust, isNothing)
import qualified Data.Map.Strict as M
import Data.Set (Set)
import qualified Data.Set as S
@ -51,16 +51,7 @@ data DPOR = DPOR
, dporAction :: Maybe ThreadAction
-- ^ What happened at this step. This will be 'Nothing' at the root,
-- 'Just' everywhere else.
}
instance NFData DPOR where
rnf dpor = rnf ( dporRunnable dpor
, dporTodo dpor
, dporDone dpor
, dporSleep dpor
, dporTaken dpor
, dporAction dpor
)
} deriving Show
-- | One step of the execution, including information for backtracking
-- purposes. This backtracking information is used to generate new
@ -68,7 +59,9 @@ instance NFData DPOR where
data BacktrackStep = BacktrackStep
{ bcktThreadid :: ThreadId
-- ^ The thread running at this step
, bcktDecision :: (Decision, ThreadAction)
, bcktDecision :: Decision
-- ^ What was decided at this step.
, bcktAction :: ThreadAction
-- ^ What happened at this step.
, bcktRunnable :: Map ThreadId Lookahead
-- ^ The threads runnable at this step
@ -77,15 +70,7 @@ data BacktrackStep = BacktrackStep
-- alternatives were added conservatively due to the bound.
, bcktState :: DepState
-- ^ Some domain-specific state at this point.
}
instance NFData BacktrackStep where
rnf b = rnf ( bcktThreadid b
, bcktDecision b
, bcktRunnable b
, bcktBacktracks b
, bcktState b
)
} deriving Show
-- | Initial DPOR state, given an initial thread ID. This initial
-- thread should exist and be runnable at the start of execution.
@ -127,24 +112,17 @@ findSchedulePrefix predicate idx dpor0
(ts, c, slp) = allPrefixes !! i
in Just (ts, c, slp, g)
where
allPrefixes = go (initialDPORThread dpor0) dpor0
allPrefixes = go dpor0
go tid dpor =
go dpor =
-- All the possible prefix traces from this point, with
-- updated DPOR subtrees if taken from the done list.
let prefixes = concatMap go' (M.toList $ dporDone dpor) ++ here dpor
-- Sort by number of preemptions, in descending order.
cmp = Down . preEmps tid dpor . (\(a,_,_) -> a)
sorted = sortBy (comparing cmp) prefixes
in if null prefixes
then []
else case partition (\(t:_,_,_) -> predicate t) sorted of
([], []) -> err "findSchedulePrefix" "empty prefix list!"
let prefixes = here dpor : map go' (M.toList $ dporDone dpor)
in case concatPartition (\(t:_,_,_) -> predicate t) prefixes of
([], choices) -> choices
(choices, _) -> choices
go' (tid, dpor) = (\(ts,c,slp) -> (tid:ts,c,slp)) <$> go tid dpor
go' (tid, dpor) = (\(ts,c,slp) -> (tid:ts,c,slp)) <$> go dpor
-- Prefix traces terminating with a to-do decision at this point.
here dpor = [([t], c, sleeps dpor) | (t, c) <- M.toList $ dporTodo dpor]
@ -154,16 +132,10 @@ findSchedulePrefix predicate idx dpor0
-- explored.
sleeps dpor = dporSleep dpor `M.union` dporTaken dpor
-- The number of pre-emptive context switches
preEmps tid dpor (t:ts) =
let rest = preEmps t (fromJust . M.lookup t $ dporDone dpor) ts
in if tid `S.member` dporRunnable dpor then 1 + rest else rest
preEmps _ _ [] = 0::Int
-- | Add a new trace to the tree, creating a new subtree branching off
-- at the point where the \"to-do\" decision was made.
incorporateTrace
:: (DepState -> (ThreadId, ThreadAction) -> (ThreadId, ThreadAction) -> Bool)
:: (DepState -> ThreadId -> ThreadAction -> ThreadId -> ThreadAction -> Bool)
-- ^ Dependency function
-> Bool
-- ^ Whether the \"to-do\" point which was used to create this new
@ -176,7 +148,7 @@ incorporateTrace
incorporateTrace dependency conservative trace dpor0 = grow initialDepState (initialDPORThread dpor0) trace dpor0 where
grow state tid trc@((d, _, a):rest) dpor =
let tid' = tidOf tid d
state' = updateDepState state (tid', a)
state' = updateDepState state tid' a
in case M.lookup tid' (dporDone dpor) of
Just dpor' ->
let done = M.insert tid' (grow state' tid' rest dpor') (dporDone dpor)
@ -193,8 +165,8 @@ incorporateTrace dependency conservative trace dpor0 = grow initialDepState (ini
-- Construct a new subtree corresponding to a trace suffix.
subtree state tid sleep ((_, _, a):rest) =
let state' = updateDepState state (tid, a)
sleep' = M.filterWithKey (\t a' -> not $ dependency state' (tid, a) (t,a')) sleep
let state' = updateDepState state tid a
sleep' = M.filterWithKey (\t a' -> not $ dependency state' tid a t a') sleep
in DPOR
{ dporRunnable = S.fromList $ case rest of
((_, runnable, _):_) -> map fst runnable
@ -225,7 +197,7 @@ incorporateTrace dependency conservative trace dpor0 = grow initialDepState (ini
-- runnable, a dependency is imposed between this final action and
-- everything else.
findBacktrackSteps
:: (DepState -> (ThreadId, ThreadAction) -> (ThreadId, Lookahead) -> Bool)
:: (DepState -> ThreadId -> ThreadAction -> ThreadId -> Lookahead -> Bool)
-- ^ Dependency function.
-> BacktrackFunc
-- ^ Backtracking function. Given a list of backtracking points, and
@ -244,17 +216,16 @@ findBacktrackSteps
-> Trace
-- ^ The execution trace.
-> [BacktrackStep]
findBacktrackSteps _ _ _ bcktrck
| Sq.null bcktrck = const []
findBacktrackSteps dependency backtrack boundKill bcktrck = go initialDepState S.empty initialThread [] (Sq.viewl bcktrck) where
findBacktrackSteps dependency backtrack boundKill = go initialDepState S.empty initialThread [] . F.toList where
-- Walk through the traces one step at a time, building up a list of
-- new backtracking points.
go state allThreads tid bs ((e,i):<is) ((d,_,a):ts) =
go state allThreads tid bs ((e,i):is) ((d,_,a):ts) =
let tid' = tidOf tid d
state' = updateDepState state (tid', a)
state' = updateDepState state tid' a
this = BacktrackStep
{ bcktThreadid = tid'
, bcktDecision = (d, a)
, bcktDecision = d
, bcktAction = a
, bcktRunnable = M.fromList . toList $ e
, bcktBacktracks = M.fromList $ map (\i' -> (i', False)) i
, bcktState = state'
@ -263,30 +234,41 @@ findBacktrackSteps dependency backtrack boundKill bcktrck = go initialDepState S
runnable = S.fromList (M.keys $ bcktRunnable this)
allThreads' = allThreads `S.union` runnable
killsEarly = null ts && boundKill
in go state' allThreads' tid' bs' (Sq.viewl is) ts
in go state' allThreads' tid' bs' is ts
go _ _ _ bs _ _ = bs
-- Find the prior actions dependent with this one and add
-- backtracking points.
doBacktrack killsEarly allThreads enabledThreads bs =
let tagged = reverse $ zip [0..] bs
idxs = [ (head is, u)
idxs = [ (head is, False, u)
| (u, n) <- enabledThreads
, v <- S.toList allThreads
, u /= v
, let is = idxs' u n v tagged
, not $ null is]
idxs' u n v = mapMaybe go' where
go' (i, b)
idxs' u n v = catMaybes . go' True where
{-# INLINE go' #-}
go' final ((i,b):rest)
-- Don't cross subconcurrency boundaries
| isSubC final b = []
-- If this is the final action in the trace and the
-- execution was killed due to nothing being within bounds
-- (@killsEarly == True@) assume worst-case dependency.
| bcktThreadid b == v && (killsEarly || isDependent b) = Just i
| otherwise = Nothing
| bcktThreadid b == v && (killsEarly || isDependent b) = Just i : go' False rest
| otherwise = go' False rest
go' _ [] = []
isDependent b = dependency (bcktState b) (bcktThreadid b, snd $ bcktDecision b) (u, n)
in foldl' (\b (i, u) -> backtrack b i u) bs idxs
{-# INLINE isSubC #-}
isSubC final b = case bcktAction b of
Stop -> not final && bcktThreadid b == initialThread
Subconcurrency -> bcktThreadid b == initialThread
_ -> False
{-# INLINE isDependent #-}
isDependent b = dependency (bcktState b) (bcktThreadid b) (bcktAction b) u n
in backtrack bs idxs
-- | Add new backtracking points, if they have not already been
-- visited, fit into the bound, and aren't in the sleep set.
@ -302,10 +284,9 @@ incorporateBacktrackSteps bv = go Nothing [] where
go priorTid pref (b:bs) bpor =
let bpor' = doBacktrack priorTid pref b bpor
tid = bcktThreadid b
pref' = pref ++ [bcktDecision b]
pref' = pref ++ [(bcktDecision b, bcktAction b)]
child = go (Just tid) pref' bs . fromJust $ M.lookup tid (dporDone bpor)
in bpor' { dporDone = M.insert tid child $ dporDone bpor' }
go _ _ [] bpor = bpor
doBacktrack priorTid pref b bpor =
@ -343,16 +324,7 @@ data SchedState = SchedState
, schedDepState :: DepState
-- ^ State used by the dependency function to determine when to
-- remove decisions from the sleep set.
}
instance NFData SchedState where
rnf s = rnf ( schedSleep s
, schedPrefix s
, schedBPoints s
, schedIgnore s
, schedBoundKill s
, schedDepState s
)
} deriving Show
-- | Initial scheduler state for a given prefix
initialSchedState :: Map ThreadId ThreadAction
@ -378,52 +350,60 @@ type BoundFunc
-- | A backtracking step is a point in the execution where another
-- decision needs to be made, in order to explore interesting new
-- schedules. A backtracking /function/ takes the steps identified so
-- far and a point and a thread to backtrack to, and inserts at least
-- that backtracking point. More may be added to compensate for the
-- effects of the bounding function. For example, under pre-emption
-- bounding a conservative backtracking point is added at the prior
-- context switch.
-- far and a list of points and thread at that point to backtrack
-- to. More points be added to compensate for the effects of the
-- bounding function. For example, under pre-emption bounding a
-- conservative backtracking point is added at the prior context
-- switch. The bool is whether the point is conservative. Conservative
-- points are always explored, whereas non-conservative ones might be
-- skipped based on future information.
--
-- In general, a backtracking function should identify one or more
-- backtracking points, and then use @backtrackAt@ to do the actual
-- work.
type BacktrackFunc
= [BacktrackStep] -> Int -> ThreadId -> [BacktrackStep]
= [BacktrackStep] -> [(Int, Bool, ThreadId)] -> [BacktrackStep]
-- | Add a backtracking point. If the thread isn't runnable, add all
-- runnable threads. If the backtracking point is already present,
-- don't re-add it UNLESS this would make it conservative.
backtrackAt
:: (BacktrackStep -> Bool)
:: (ThreadId -> BacktrackStep -> Bool)
-- ^ If this returns @True@, backtrack to all runnable threads,
-- rather than just the given thread.
-> Bool
-- ^ Is this backtracking point conservative? Conservative points
-- are always explored, whereas non-conservative ones might be
-- skipped based on future information.
-> BacktrackFunc
backtrackAt toAll conservative bs i tid = go bs i where
go bx@(b:rest) 0
backtrackAt toAll bs0 = backtrackAt' . nubBy ((==) `on` fst') . sortOn fst' where
fst' (x,_,_) = x
backtrackAt' ((i,c,t):is) = go i bs0 i c t is
backtrackAt' [] = bs0
go i0 (b:bs) 0 c tid is
-- If the backtracking point is already present, don't re-add it,
-- UNLESS this would force it to backtrack (it's conservative)
-- where before it might not.
| not (toAll b) && tid `M.member` bcktRunnable b =
| not (toAll tid b) && tid `M.member` bcktRunnable b =
let val = M.lookup tid $ bcktBacktracks b
in if isNothing val || (val == Just False && conservative)
then b { bcktBacktracks = backtrackTo b } : rest
else bx
b' = if isNothing val || (val == Just False && c)
then b { bcktBacktracks = backtrackTo tid c b }
else b
in b' : case is of
((i',c',t'):is') -> go i' bs (i'-i0-1) c' t' is'
[] -> bs
-- Otherwise just backtrack to everything runnable.
| otherwise = b { bcktBacktracks = backtrackAll b } : rest
go (b:rest) n = b : go rest (n-1)
go [] _ = error "backtrackAt: Ran out of schedule whilst backtracking!"
| otherwise =
let b' = b { bcktBacktracks = backtrackAll c b }
in b' : case is of
((i',c',t'):is') -> go i' bs (i'-i0-1) c' t' is'
[] -> bs
go i0 (b:bs) i c tid is = b : go i0 bs (i-1) c tid is
go _ [] _ _ _ _ = err "backtrackAt" "ran out of schedule whilst backtracking!"
-- Backtrack to a single thread
backtrackTo = M.insert tid conservative . bcktBacktracks
backtrackTo tid c = M.insert tid c . bcktBacktracks
-- Backtrack to all runnable threads
backtrackAll = M.map (const conservative) . bcktRunnable
backtrackAll c = M.map (const c) . bcktRunnable
-- | DPOR scheduler: takes a list of decisions, and maintains a trace
-- including the runnable threads, and the alternative choices allowed
@ -433,17 +413,14 @@ backtrackAt toAll conservative bs i tid = go bs i where
-- the prior thread if it's (1) still runnable and (2) hasn't just
-- yielded. Furthermore, threads which /will/ yield are ignored in
-- preference of those which will not.
--
-- This forces full evaluation of the result every step, to avoid any
-- possible space leaks.
dporSched
:: (DepState -> (ThreadId, ThreadAction) -> (ThreadId, ThreadAction) -> Bool)
:: (DepState -> ThreadId -> ThreadAction -> ThreadId -> ThreadAction -> Bool)
-- ^ Dependency function.
-> BoundFunc
-- ^ Bound function: returns true if that schedule prefix terminated
-- with the lookahead decision fits within the bound.
-> Scheduler SchedState
dporSched dependency inBound trc prior threads s = force schedule where
dporSched dependency inBound trc prior threads s = schedule where
-- Pick a thread to run.
schedule = case schedPrefix s of
-- If there is a decision available, make it
@ -455,7 +432,7 @@ dporSched dependency inBound trc prior threads s = force schedule where
[] ->
let choices = restrictToBound initialise
checkDep t a = case prior of
Just (tid, act) -> dependency (schedDepState s) (tid, act) (t, a)
Just (tid, act) -> dependency (schedDepState s) tid act t a
Nothing -> False
ssleep' = M.filterWithKey (\t a -> not $ checkDep t a) $ schedSleep s
choices' = filter (`notElem` M.keys ssleep') choices
@ -470,7 +447,7 @@ dporSched dependency inBound trc prior threads s = force schedule where
{ schedBPoints = schedBPoints s |> (threads, rest)
, schedDepState = nextDepState
}
nextDepState = let ds = schedDepState s in maybe ds (updateDepState ds) prior
nextDepState = let ds = schedDepState s in maybe ds (uncurry $ updateDepState ds) prior
-- Pick a new thread to run, not considering bounds. Choose the
-- current thread if available and it hasn't just yielded, otherwise
@ -537,11 +514,7 @@ data DepState = DepState
-- the masking state is assumed to be @Unmasked@. This nicely
-- provides compatibility with dpor-0.1, where the thread IDs are
-- not available.
}
instance NFData DepState where
-- Cheats: 'MaskingState' has no 'NFData' instance.
rnf ds = rnf (depCRState ds, M.keys (depMaskState ds))
} deriving (Eq, Show)
-- | Initial dependency state.
initialDepState :: DepState
@ -549,8 +522,8 @@ initialDepState = DepState M.empty M.empty
-- | Update the 'CRef' buffer state with the action that has just
-- happened.
updateDepState :: DepState -> (ThreadId, ThreadAction) -> DepState
updateDepState depstate (tid, act) = DepState
updateDepState :: DepState -> ThreadId -> ThreadAction -> DepState
updateDepState depstate tid act = DepState
{ depCRState = updateCRState act $ depCRState depstate
, depMaskState = updateMaskState tid act $ depMaskState depstate
}
@ -698,3 +671,17 @@ toDotFiltered check showTid showAct = digraph . go "L" where
-- | Internal errors.
err :: String -> String -> a
err func msg = error (func ++ ": (internal error) " ++ msg)
-- | A combination of 'partition' and 'concat'.
concatPartition :: (a -> Bool) -> [[a]] -> ([a], [a])
{-# INLINE concatPartition #-}
-- note: `foldr (flip (foldr select))` is slow, as is `foldl (foldl
-- select))`, and `foldl'` variants. The sweet spot seems to be `foldl
-- (foldr select)` for some reason I don't really understand.
concatPartition p = foldl (foldr select) ([], []) where
-- Lazy pattern matching, got this trick from the 'partition'
-- implementation. This reduces allocation fairly significantly; I
-- do not know why.
select a ~(ts, fs)
| p a = (a:ts, fs)
| otherwise = (ts, a:fs)

View File

@ -96,7 +96,6 @@ library
build-depends: base >=4.8 && <5
, concurrency >=1.0 && <1.1
, containers >=0.5 && <0.6
, deepseq >=1.3 && <1.5
, exceptions >=0.7 && <0.9
, monad-loops >=0.4 && <0.5
, mtl >=2.2 && <2.3