diff --git a/dejafu-tests/Cases/MultiThreaded.hs b/dejafu-tests/Cases/MultiThreaded.hs index e84b940..43e602b 100644 --- a/dejafu-tests/Cases/MultiThreaded.hs +++ b/dejafu-tests/Cases/MultiThreaded.hs @@ -11,6 +11,7 @@ import Test.HUnit.DejaFu (testDejafu) import Control.Concurrent.Classy import Control.Monad.STM.Class +import Test.DejaFu.Conc (Conc, subconcurrency) #if __GLASGOW_HASKELL__ < 710 import Control.Applicative ((<$>), (<*>)) @@ -49,6 +50,13 @@ tests = , testGroup "Daemons" . hUnitTestToTests $ test [ testDejafu schedDaemon "schedule daemon" $ gives' [0,1] ] + + , testGroup "Subconcurrency" . hUnitTestToTests $ test + [ testDejafu scDeadlock1 "deadlock1" $ gives' [Left Deadlock, Right ()] + , testDejafu scDeadlock2 "deadlock2" $ gives' [(Left Deadlock, ()), (Right (), ())] + , testDejafu scSuccess "success" $ gives' [Right ()] + , testDejafu scIllegal "illegal" $ gives [Left IllegalSubconcurrency] + ] ] -------------------------------------------------------------------------------- @@ -207,3 +215,40 @@ schedDaemon = do x <- newCRef 0 _ <- fork $ myThreadId >> writeCRef x 1 readCRef x + +-------------------------------------------------------------------------------- +-- Subconcurrency + +-- | Subcomputation deadlocks sometimes. +scDeadlock1 :: Monad n => Conc n r (Either Failure ()) +scDeadlock1 = do + var <- newEmptyMVar + subconcurrency $ do + void . fork $ putMVar var () + putMVar var () + +-- | Subcomputation deadlocks sometimes, and action after it still +-- happens. +scDeadlock2 :: Monad n => Conc n r (Either Failure (), ()) +scDeadlock2 = do + var <- newEmptyMVar + res <- subconcurrency $ do + void . fork $ putMVar var () + putMVar var () + (,) <$> pure res <*> readMVar var + +-- | Subcomputation successfully completes. +scSuccess :: Monad n => Conc n r (Either Failure ()) +scSuccess = do + var <- newMVar () + subconcurrency $ do + out <- newEmptyMVar + void . fork $ takeMVar var >>= putMVar out + takeMVar out + +-- | Illegal usage +scIllegal :: Monad n => Conc n r () +scIllegal = do + var <- newEmptyMVar + void . fork $ readMVar var + void . subconcurrency $ pure () diff --git a/dejafu-tests/Cases/SingleThreaded.hs b/dejafu-tests/Cases/SingleThreaded.hs index 8ea54ee..698f0d7 100644 --- a/dejafu-tests/Cases/SingleThreaded.hs +++ b/dejafu-tests/Cases/SingleThreaded.hs @@ -3,6 +3,7 @@ module Cases.SingleThreaded where import Control.Exception (ArithException(..), ArrayException(..)) +import Control.Monad (void) import Test.DejaFu (Failure(..), gives, gives') import Test.Framework (Test, testGroup) import Test.Framework.Providers.HUnit (hUnitTestToTests) @@ -10,7 +11,7 @@ import Test.HUnit (test) import Test.HUnit.DejaFu (testDejafu) import Control.Concurrent.Classy -import Control.Monad.STM.Class +import Test.DejaFu.Conc (Conc, subconcurrency) import Utils @@ -58,6 +59,12 @@ tests = [ testDejafu capsGet "get" $ gives' [True] , testDejafu capsSet "set" $ gives' [True] ] + + , testGroup "Subconcurrency" . hUnitTestToTests $ test + [ testDejafu scDeadlock1 "deadlock1" $ gives' [Left Deadlock] + , testDejafu scDeadlock2 "deadlock2" $ gives' [(Left Deadlock, ())] + , testDejafu scSuccess "success" $ gives' [Right ()] + ] ] -------------------------------------------------------------------------------- @@ -252,3 +259,22 @@ capsSet = do caps <- getNumCapabilities setNumCapabilities $ caps + 1 (== caps + 1) <$> getNumCapabilities + +-------------------------------------------------------------------------------- +-- Subconcurrency + +-- | Subcomputation deadlocks. +scDeadlock1 :: Monad n => Conc n r (Either Failure ()) +scDeadlock1 = subconcurrency (newEmptyMVar >>= readMVar) + +-- | Subcomputation deadlocks, and action after it still happens. +scDeadlock2 :: Monad n => Conc n r (Either Failure (), ()) +scDeadlock2 = do + var <- newMVar () + (,) <$> subconcurrency (putMVar var ()) <*> readMVar var + +-- | Subcomputation successfully completes. +scSuccess :: Monad n => Conc n r (Either Failure ()) +scSuccess = do + var <- newMVar () + subconcurrency (takeMVar var) diff --git a/dejafu/Test/DejaFu.hs b/dejafu/Test/DejaFu.hs index 51c3a2a..28af201 100644 --- a/dejafu/Test/DejaFu.hs +++ b/dejafu/Test/DejaFu.hs @@ -240,7 +240,6 @@ module Test.DejaFu ) where import Control.Arrow (first) -import Control.DeepSeq (NFData(..)) import Control.Monad (when, unless) import Control.Monad.Ref (MonadRef) import Control.Monad.ST (runST) @@ -400,9 +399,6 @@ defaultFail failures = Result False 0 failures "" defaultPass :: Result a defaultPass = Result True 0 [] "" -instance NFData a => NFData (Result a) where - rnf r = rnf (_pass r, _casesChecked r, _failures r, _failureMsg r) - instance Functor Result where fmap f r = r { _failures = map (first $ fmap f) $ _failures r } diff --git a/dejafu/Test/DejaFu/Common.hs b/dejafu/Test/DejaFu/Common.hs index 3a17a9e..ae67cfe 100644 --- a/dejafu/Test/DejaFu/Common.hs +++ b/dejafu/Test/DejaFu/Common.hs @@ -60,7 +60,6 @@ module Test.DejaFu.Common , MemType(..) ) where -import Control.DeepSeq (NFData(..)) import Control.Exception (MaskingState(..)) import Data.Dynamic (Dynamic) import Data.List (sort, nub, intercalate) @@ -83,9 +82,6 @@ instance Show ThreadId where show (ThreadId (Just n) _) = n show (ThreadId Nothing i) = show i -instance NFData ThreadId where - rnf (ThreadId n i) = rnf (n, i) - -- | Every @CRef@ has a unique identifier. data CRefId = CRefId (Maybe String) Int deriving Eq @@ -97,9 +93,6 @@ instance Show CRefId where show (CRefId (Just n) _) = n show (CRefId Nothing i) = show i -instance NFData CRefId where - rnf (CRefId n i) = rnf (n, i) - -- | Every @MVar@ has a unique identifier. data MVarId = MVarId (Maybe String) Int deriving Eq @@ -111,9 +104,6 @@ instance Show MVarId where show (MVarId (Just n) _) = n show (MVarId Nothing i) = show i -instance NFData MVarId where - rnf (MVarId n i) = rnf (n, i) - -- | Every @TVar@ has a unique identifier. data TVarId = TVarId (Maybe String) Int deriving Eq @@ -125,9 +115,6 @@ instance Show TVarId where show (TVarId (Just n) _) = n show (TVarId Nothing i) = show i -instance NFData TVarId where - rnf (TVarId n i) = rnf (n, i) - -- | The ID of the initial thread. initialThread :: ThreadId initialThread = ThreadId (Just "main") 0 @@ -272,38 +259,10 @@ data ThreadAction = -- ^ A '_concMessage' annotation was processed. | Stop -- ^ Cease execution and terminate. + | Subconcurrency + -- ^ Start executing an action with @subconcurrency@. deriving Show -instance NFData ThreadAction where - rnf (Fork t) = rnf t - rnf (GetNumCapabilities i) = rnf i - rnf (SetNumCapabilities i) = rnf i - rnf (NewVar c) = rnf c - rnf (PutVar c ts) = rnf (c, ts) - rnf (BlockedPutVar c) = rnf c - rnf (TryPutVar c b ts) = rnf (c, b, ts) - rnf (ReadVar c) = rnf c - rnf (BlockedReadVar c) = rnf c - rnf (TakeVar c ts) = rnf (c, ts) - rnf (BlockedTakeVar c) = rnf c - rnf (TryTakeVar c b ts) = rnf (c, b, ts) - rnf (NewRef c) = rnf c - rnf (ReadRef c) = rnf c - rnf (ReadRefCas c) = rnf c - rnf (ModRef c) = rnf c - rnf (ModRefCas c) = rnf c - rnf (WriteRef c) = rnf c - rnf (CasRef c b) = rnf (c, b) - rnf (CommitRef t c) = rnf (t, c) - rnf (STM s ts) = rnf (s, ts) - rnf (BlockedSTM s) = rnf s - rnf (ThrowTo t) = rnf t - rnf (BlockedThrowTo t) = rnf t - rnf (SetMasking b m) = b `seq` m `seq` () - rnf (ResetMasking b m) = b `seq` m `seq` () - rnf (Message m) = m `seq` () - rnf a = a `seq` () - -- | Check if a @ThreadAction@ immediately blocks. isBlock :: ThreadAction -> Bool isBlock (BlockedThrowTo _) = True @@ -401,28 +360,10 @@ data Lookahead = -- ^ Will process a _concMessage' annotation. | WillStop -- ^ Will cease execution and terminate. + | WillSubconcurrency + -- ^ Will execute an action with @subconcurrency@. deriving Show -instance NFData Lookahead where - rnf (WillSetNumCapabilities i) = rnf i - rnf (WillPutVar c) = rnf c - rnf (WillTryPutVar c) = rnf c - rnf (WillReadVar c) = rnf c - rnf (WillTakeVar c) = rnf c - rnf (WillTryTakeVar c) = rnf c - rnf (WillReadRef c) = rnf c - rnf (WillReadRefCas c) = rnf c - rnf (WillModRef c) = rnf c - rnf (WillModRefCas c) = rnf c - rnf (WillWriteRef c) = rnf c - rnf (WillCasRef c) = rnf c - rnf (WillCommitRef t c) = rnf (t, c) - rnf (WillThrowTo t) = rnf t - rnf (WillSetMasking b m) = b `seq` m `seq` () - rnf (WillResetMasking b m) = b `seq` m `seq` () - rnf (WillMessage m) = m `seq` () - rnf l = l `seq` () - -- | Convert a 'ThreadAction' into a 'Lookahead': \"rewind\" what has -- happened. 'Killed' has no 'Lookahead' counterpart. rewind :: ThreadAction -> Maybe Lookahead @@ -462,6 +403,7 @@ rewind LiftIO = Just WillLiftIO rewind Return = Just WillReturn rewind (Message m) = Just (WillMessage m) rewind Stop = Just WillStop +rewind Subconcurrency = Just WillSubconcurrency -- | Check if an operation could enable another thread. willRelease :: Lookahead -> Bool @@ -508,17 +450,6 @@ data ActionType = -- communication. deriving (Eq, Show) -instance NFData ActionType where - rnf (UnsynchronisedRead r) = rnf r - rnf (UnsynchronisedWrite r) = rnf r - rnf (PartiallySynchronisedCommit r) = rnf r - rnf (PartiallySynchronisedWrite r) = rnf r - rnf (PartiallySynchronisedModify r) = rnf r - rnf (SynchronisedModify r) = rnf r - rnf (SynchronisedRead c) = rnf c - rnf (SynchronisedWrite c) = rnf c - rnf a = a `seq` () - -- | Check if an action imposes a write barrier. isBarrier :: ActionType -> Bool isBarrier (SynchronisedModify _) = True @@ -611,13 +542,6 @@ data TAction = -- ^ Terminate successfully and commit effects. deriving (Eq, Show) -instance NFData TAction where - rnf (TRead v) = rnf v - rnf (TWrite v) = rnf v - rnf (TCatch s m) = rnf (s, m) - rnf (TOrElse s m) = rnf (s, m) - rnf a = a `seq` () - ------------------------------------------------------------------------------- -- Traces @@ -640,11 +564,6 @@ data Decision = -- ^ Pre-empt the running thread, and switch to another. deriving (Eq, Show) -instance NFData Decision where - rnf (Start tid) = rnf tid - rnf (SwitchTo tid) = rnf tid - rnf d = d `seq` () - -- | Pretty-print a trace, including a key of the thread IDs (not -- including thread 0). Each line of the key is indented by two -- spaces. @@ -675,15 +594,15 @@ showTrace trc = intercalate "\n" $ concatMap go trc : strkey where preEmpCount :: [(Decision, ThreadAction)] -> (Decision, Lookahead) -> Int -preEmpCount ts (d, _) = go initialThread Nothing ts where - go _ (Just Yield) ((SwitchTo t, a):rest) = go t (Just a) rest - go tid prior ((SwitchTo t, a):rest) +preEmpCount (x:xs) (d, _) = go initialThread x xs where + go _ (_, Yield) (r@(SwitchTo t, _):rest) = go t r rest + go tid prior (r@(SwitchTo t, _):rest) | isCommitThread t = go tid prior (skip rest) - | otherwise = 1 + go t (Just a) rest - go _ _ ((Start t, a):rest) = go t (Just a) rest - go tid _ ((Continue, a):rest) = go tid (Just a) rest + | otherwise = 1 + go t r rest + go _ _ (r@(Start t, _):rest) = go t r rest + go tid _ (r@(Continue, _):rest) = go tid r rest go _ prior [] = case (prior, d) of - (Just Yield, SwitchTo _) -> 0 + ((_, Yield), SwitchTo _) -> 0 (_, SwitchTo _) -> 1 _ -> 0 @@ -694,6 +613,7 @@ preEmpCount ts (d, _) = go initialThread Nothing ts where skip = dropWhile (not . isContextSwitch . fst) isContextSwitch Continue = False isContextSwitch _ = True +preEmpCount [] _ = 0 ------------------------------------------------------------------------------- -- Failures @@ -716,18 +636,19 @@ data Failure = -- ^ The computation became blocked indefinitely on @TVar@s. | UncaughtException -- ^ An uncaught exception bubbled to the top of the computation. + | IllegalSubconcurrency + -- ^ Calls to @subconcurrency@ were nested, or attempted when + -- multiple threads existed. deriving (Eq, Show, Read, Ord, Enum, Bounded) -instance NFData Failure where - rnf f = f `seq` () - -- | Pretty-print a failure showFail :: Failure -> String -showFail Abort = "[abort]" -showFail Deadlock = "[deadlock]" -showFail STMDeadlock = "[stm-deadlock]" -showFail InternalError = "[internal-error]" +showFail Abort = "[abort]" +showFail Deadlock = "[deadlock]" +showFail STMDeadlock = "[stm-deadlock]" +showFail InternalError = "[internal-error]" showFail UncaughtException = "[exception]" +showFail IllegalSubconcurrency = "[illegal-subconcurrency]" ------------------------------------------------------------------------------- -- Memory Models @@ -751,9 +672,6 @@ data MemType = -- created. deriving (Eq, Show, Read, Ord, Enum, Bounded) -instance NFData MemType where - rnf m = m `seq` () - ------------------------------------------------------------------------------- -- Utilities diff --git a/dejafu/Test/DejaFu/Conc.hs b/dejafu/Test/DejaFu/Conc.hs index 6ecdc25..58cae52 100755 --- a/dejafu/Test/DejaFu/Conc.hs +++ b/dejafu/Test/DejaFu/Conc.hs @@ -28,6 +28,7 @@ module Test.DejaFu.Conc , Failure(..) , MemType(..) , runConcurrent + , subconcurrency -- * Execution traces , Trace @@ -49,12 +50,11 @@ import Control.Exception (MaskingState(..)) import qualified Control.Monad.Base as Ba import qualified Control.Monad.Catch as Ca import qualified Control.Monad.IO.Class as IO -import Control.Monad.Ref (MonadRef, newRef, readRef, writeRef) +import Control.Monad.Ref (MonadRef,) import Control.Monad.ST (ST) import Data.Dynamic (toDyn) +import qualified Data.Foldable as F import Data.IORef (IORef) -import qualified Data.Map.Strict as M -import Data.Maybe (fromJust) import Data.STRef (STRef) import Test.DejaFu.Schedule @@ -62,13 +62,12 @@ import qualified Control.Monad.Conc.Class as C import Test.DejaFu.Common import Test.DejaFu.Conc.Internal import Test.DejaFu.Conc.Internal.Common -import Test.DejaFu.Conc.Internal.Threading import Test.DejaFu.STM {-# ANN module ("HLint: ignore Avoid lambda" :: String) #-} {-# ANN module ("HLint: ignore Use const" :: String) #-} -newtype Conc n r a = C { unC :: M n r (STMLike n r) a } deriving (Functor, Applicative, Monad) +newtype Conc n r a = C { unC :: M n r a } deriving (Functor, Applicative, Monad) -- | A 'MonadConc' implementation using @ST@, this should be preferred -- if you do not need 'liftIO'. @@ -77,10 +76,10 @@ type ConcST t = Conc (ST t) (STRef t) -- | A 'MonadConc' implementation using @IO@. type ConcIO = Conc IO IORef -toConc :: ((a -> Action n r (STMLike n r)) -> Action n r (STMLike n r)) -> Conc n r a +toConc :: ((a -> Action n r) -> Action n r) -> Conc n r a toConc = C . cont -wrap :: (M n r (STMLike n r) a -> M n r (STMLike n r) a) -> Conc n r a -> Conc n r a +wrap :: (M n r a -> M n r a) -> Conc n r a -> Conc n r a wrap f = C . f . unC instance IO.MonadIO ConcIO where @@ -181,20 +180,15 @@ runConcurrent :: MonadRef r n -> s -> Conc n r a -> n (Either Failure a, s, Trace) -runConcurrent sched memtype s (C conc) = do - ref <- newRef Nothing +runConcurrent sched memtype s ma = do + (res, s', trace) <- runConcurrency sched memtype s (unC ma) + pure (res, s', F.toList trace) - let c = runCont conc (AStop . writeRef ref . Just . Right) - let threads = launch' Unmasked initialThread (const c) M.empty - - (s', trace) <- runThreads runTransaction - sched - memtype - s - threads - initialIdSource - ref - - out <- readRef ref - - pure (fromJust out, s', reverse trace) +-- | Run a concurrent computation and return its result. +-- +-- This can only be called in the main thread, when no other threads +-- exist. Calls to 'subconcurrency' cannot be nested. Violating either +-- of these conditions will result in the computation failing with +-- @IllegalSubconcurrency@. +subconcurrency :: Conc n r a -> Conc n r (Either Failure a) +subconcurrency ma = toConc (ASub (unC ma)) diff --git a/dejafu/Test/DejaFu/Conc/Internal.hs b/dejafu/Test/DejaFu/Conc/Internal.hs index b3a2af7..7b821e8 100755 --- a/dejafu/Test/DejaFu/Conc/Internal.hs +++ b/dejafu/Test/DejaFu/Conc/Internal.hs @@ -16,19 +16,23 @@ module Test.DejaFu.Conc.Internal where import Control.Exception (MaskingState(..), toException) -import Control.Monad.Ref (MonadRef, newRef, writeRef) +import Control.Monad.Ref (MonadRef, newRef, readRef, writeRef) +import qualified Data.Foldable as F import Data.Functor (void) import Data.List (sort) import Data.List.NonEmpty (NonEmpty(..), fromList) import qualified Data.Map.Strict as M -import Data.Maybe (fromJust, isJust, isNothing, listToMaybe) +import Data.Maybe (fromJust, isJust, isNothing) +import Data.Monoid ((<>)) +import Data.Sequence (Seq, (<|)) +import qualified Data.Sequence as Seq import Test.DejaFu.Common import Test.DejaFu.Conc.Internal.Common import Test.DejaFu.Conc.Internal.Memory import Test.DejaFu.Conc.Internal.Threading import Test.DejaFu.Schedule -import Test.DejaFu.STM (Result(..)) +import Test.DejaFu.STM (Result(..), runTransaction) {-# ANN module ("HLint: ignore Use record patterns" :: String) #-} {-# ANN module ("HLint: ignore Use const" :: String) #-} @@ -36,40 +40,69 @@ import Test.DejaFu.STM (Result(..)) -------------------------------------------------------------------------------- -- * Execution +-- | 'Trace' but as a sequence. +type SeqTrace + = Seq (Decision, [(ThreadId, NonEmpty Lookahead)], ThreadAction) + +-- | Run a concurrent computation with a given 'Scheduler' and initial +-- state, returning a failure reason on error. Also returned is the +-- final state of the scheduler, and an execution trace. +runConcurrency :: MonadRef r n + => Scheduler g + -> MemType + -> g + -> M n r a + -> n (Either Failure a, g, SeqTrace) +runConcurrency sched memtype g ma = do + ref <- newRef Nothing + + let c = runCont ma (AStop . writeRef ref . Just . Right) + let threads = launch' Unmasked initialThread (const c) M.empty + let ctx = Context { cSchedState = g, cIdSource = initialIdSource, cThreads = threads, cWriteBuf = emptyBuffer, cCaps = 2 } + + (finalCtx, trace) <- runThreads sched memtype ref ctx + out <- readRef ref + pure (fromJust out, cSchedState finalCtx, trace) + +-- | The context a collection of threads are running in. +data Context n r g = Context + { cSchedState :: g + , cIdSource :: IdSource + , cThreads :: Threads n r + , cWriteBuf :: WriteBuffer r + , cCaps :: Int + } + -- | Run a collection of threads, until there are no threads left. --- --- Note: this returns the trace in reverse order, because it's more --- efficient to prepend to a list than append. As this function isn't --- exposed to users of the library, this is just an internal gotcha to --- watch out for. -runThreads :: MonadRef r n => (forall x. s x -> IdSource -> n (Result x, IdSource, TTrace)) - -> Scheduler g -> MemType -> g -> Threads n r s -> IdSource -> r (Maybe (Either Failure a)) -> n (g, Trace) -runThreads runstm sched memtype origg origthreads idsrc ref = go idsrc [] Nothing origg origthreads emptyBuffer 2 where - go idSource sofar prior g threads wb caps - | isTerminated = stop g - | isDeadlocked = die g Deadlock - | isSTMLocked = die g STMDeadlock - | isAborted = die g' Abort - | isNonexistant = die g' InternalError - | isBlocked = die g' InternalError +runThreads :: MonadRef r n + => Scheduler g -> MemType -> r (Maybe (Either Failure a)) -> Context n r g -> n (Context n r g, SeqTrace) +runThreads sched memtype ref = go Seq.empty [] Nothing where + -- sofar is the 'SeqTrace', sofarSched is the @[(Decision, + -- ThreadAction)]@ trace the scheduler needs. + go sofar sofarSched prior ctx + | isTerminated = stop ctx + | isDeadlocked = die Deadlock ctx + | isSTMLocked = die STMDeadlock ctx + | isAborted = die Abort $ ctx { cSchedState = g' } + | isNonexistant = die InternalError $ ctx { cSchedState = g' } + | isBlocked = die InternalError $ ctx { cSchedState = g' } | otherwise = do - stepped <- stepThread runstm memtype (_continuation $ fromJust thread) idSource chosen threads wb caps + stepped <- stepThread sched memtype chosen (_continuation $ fromJust thread) $ ctx { cSchedState = g' } case stepped of - Right (threads', idSource', act, wb', caps') -> loop threads' idSource' act wb' caps' - + Right (ctx', actOrTrc) -> loop actOrTrc ctx' Left UncaughtException - | chosen == initialThread -> die g' UncaughtException - | otherwise -> loop (kill chosen threads) idSource Killed wb caps - - Left failure -> die g' failure + | chosen == initialThread -> die UncaughtException $ ctx { cSchedState = g' } + | otherwise -> loop (Right Killed) $ ctx { cThreads = kill chosen threadsc, cSchedState = g' } + Left failure -> die failure $ ctx { cSchedState = g' } where - (choice, g') = sched (map (\(d,_,a) -> (d,a)) $ reverse sofar) ((\p (_,_,a) -> (p,a)) <$> prior <*> listToMaybe sofar) (fromList $ map (\(t,l:|_) -> (t,l)) runnable') g + (choice, g') = sched sofarSched prior (fromList $ map (\(t,l:|_) -> (t,l)) runnable') (cSchedState ctx) chosen = fromJust choice runnable' = [(t, nextActions t) | t <- sort $ M.keys runnable] runnable = M.filter (isNothing . _blocking) threadsc thread = M.lookup chosen threadsc - threadsc = addCommitThreads wb threads + threadsc = addCommitThreads (cWriteBuf ctx) threads + threads = cThreads ctx isAborted = isNothing choice isBlocked = isJust . _blocking $ fromJust thread isNonexistant = isNothing thread @@ -87,299 +120,262 @@ runThreads runstm sched memtype origg origthreads idsrc ref = go idsrc [] Nothin _ -> thrd decision - | Just chosen == prior = Continue - | prior `notElem` map (Just . fst) runnable' = Start chosen + | Just chosen == (fst <$> prior) = Continue + | (fst <$> prior) `notElem` map (Just . fst) runnable' = Start chosen | otherwise = SwitchTo chosen nextActions t = lookahead . _continuation . fromJust $ M.lookup t threadsc - stop outg = pure (outg, sofar) - die outg reason = writeRef ref (Just $ Left reason) >> stop outg + stop finalCtx = pure (finalCtx, sofar) + die reason finalCtx = writeRef ref (Just $ Left reason) >> stop finalCtx - loop threads' idSource' act wb' = - let sofar' = ((decision, runnable', act) : sofar) - threads'' = if (interruptible <$> M.lookup chosen threads') /= Just False then unblockWaitingOn chosen threads' else threads' - in go idSource' sofar' (Just chosen) g' (delCommitThreads threads'') wb' + loop trcOrAct ctx' = + let (act, trc) = case trcOrAct of + Left (a, as) -> (a, (decision, runnable', a) <| as) + Right a -> (a, Seq.singleton (decision, runnable', a)) + threads' = if (interruptible <$> M.lookup chosen (cThreads ctx')) /= Just False + then unblockWaitingOn chosen (cThreads ctx') + else cThreads ctx' + sofar' = sofar <> trc + sofarSched' = sofarSched <> map (\(d,_,a) -> (d,a)) (F.toList trc) + prior' = Just (chosen, act) + in go sofar' sofarSched' prior' $ ctx' { cThreads = delCommitThreads threads' } -------------------------------------------------------------------------------- -- * Single-step execution -- | Run a single thread one step, by dispatching on the type of -- 'Action'. -stepThread :: forall n r s. MonadRef r n - => (forall x. s x -> IdSource -> n (Result x, IdSource, TTrace)) - -- ^ Run a 'MonadSTM' transaction atomically. +stepThread :: forall n r g. MonadRef r n + => Scheduler g + -- ^ The scheduler. -> MemType - -- ^ The memory model - -> Action n r s - -- ^ Action to step - -> IdSource - -- ^ Source of fresh IDs + -- ^ The memory model to use. -> ThreadId -- ^ ID of the current thread - -> Threads n r s - -- ^ Current state of threads - -> WriteBuffer r - -- ^ @CRef@ write buffer - -> Int - -- ^ The number of capabilities - -> n (Either Failure (Threads n r s, IdSource, ThreadAction, WriteBuffer r, Int)) -stepThread runstm memtype action idSource tid threads wb caps = case action of - AFork n a b -> stepFork n a b - AMyTId c -> stepMyTId c - AGetNumCapabilities c -> stepGetNumCapabilities c - ASetNumCapabilities i c -> stepSetNumCapabilities i c - AYield c -> stepYield c - ANewVar n c -> stepNewVar n c - APutVar var a c -> stepPutVar var a c - ATryPutVar var a c -> stepTryPutVar var a c - AReadVar var c -> stepReadVar var c - ATakeVar var c -> stepTakeVar var c - ATryTakeVar var c -> stepTryTakeVar var c - ANewRef n a c -> stepNewRef n a c - AReadRef ref c -> stepReadRef ref c - AReadRefCas ref c -> stepReadRefCas ref c - AModRef ref f c -> stepModRef ref f c - AModRefCas ref f c -> stepModRefCas ref f c - AWriteRef ref a c -> stepWriteRef ref a c - ACasRef ref tick a c -> stepCasRef ref tick a c - ACommit t c -> stepCommit t c - AAtom stm c -> stepAtom stm c - ALift na -> stepLift na - AThrow e -> stepThrow e - AThrowTo t e c -> stepThrowTo t e c - ACatching h ma c -> stepCatching h ma c - APopCatching a -> stepPopCatching a - AMasking m ma c -> stepMasking m ma c - AResetMask b1 b2 m c -> stepResetMask b1 b2 m c - AReturn c -> stepReturn c - AMessage m c -> stepMessage m c - AStop na -> stepStop na + -> Action n r + -- ^ Action to step + -> Context n r g + -- ^ The execution context. + -> n (Either Failure (Context n r g, Either (ThreadAction, SeqTrace) ThreadAction)) +stepThread sched memtype tid action ctx = case action of + -- start a new thread, assigning it the next 'ThreadId' + AFork n a b -> pure . Right $ + let threads' = launch tid newtid a (cThreads ctx) + (idSource', newtid) = nextTId n (cIdSource ctx) + in (ctx { cThreads = goto (b newtid) tid threads', cIdSource = idSource' }, Right (Fork newtid)) - where - -- | Start a new thread, assigning it the next 'ThreadId' - -- - -- Explicit type signature needed for GHC 8. Looks like the - -- impredicative polymorphism checks got stronger. - stepFork :: String - -> ((forall b. M n r s b -> M n r s b) -> Action n r s) - -> (ThreadId -> Action n r s) - -> n (Either Failure (Threads n r s, IdSource, ThreadAction, WriteBuffer r, Int)) - stepFork n a b = return $ Right (goto (b newtid) tid threads', idSource', Fork newtid, wb, caps) where - threads' = launch tid newtid a threads - (idSource', newtid) = nextTId n idSource + -- get the 'ThreadId' of the current thread + AMyTId c -> simple (goto (c tid) tid (cThreads ctx)) MyThreadId - -- | Get the 'ThreadId' of the current thread - stepMyTId c = simple (goto (c tid) tid threads) MyThreadId + -- get the number of capabilities + AGetNumCapabilities c -> simple (goto (c (cCaps ctx)) tid (cThreads ctx)) $ GetNumCapabilities (cCaps ctx) - -- | Get the number of capabilities - stepGetNumCapabilities c = simple (goto (c caps) tid threads) $ GetNumCapabilities caps + -- set the number of capabilities + ASetNumCapabilities i c -> pure . Right $ + (ctx { cThreads = goto c tid (cThreads ctx), cCaps = i }, Right (SetNumCapabilities i)) - -- | Set the number of capabilities - stepSetNumCapabilities i c = return $ Right (goto c tid threads, idSource, SetNumCapabilities i, wb, i) + -- yield the current thread + AYield c -> simple (goto c tid (cThreads ctx)) Yield - -- | Yield the current thread - stepYield c = simple (goto c tid threads) Yield + -- create a new @MVar@, using the next 'MVarId'. + ANewVar n c -> do + let (idSource', newmvid) = nextMVId n (cIdSource ctx) + ref <- newRef Nothing + let mvar = MVar newmvid ref + pure $ Right (ctx { cThreads = goto (c mvar) tid (cThreads ctx), cIdSource = idSource' }, Right (NewVar newmvid)) - -- | Put a value into a @MVar@, blocking the thread until it's - -- empty. - stepPutVar cvar@(MVar cvid _) a c = synchronised $ do - (success, threads', woken) <- putIntoMVar cvar a c tid threads + -- put a value into a @MVar@, blocking the thread until it's empty. + APutVar cvar@(MVar cvid _) a c -> synchronised $ do + (success, threads', woken) <- putIntoMVar cvar a c tid (cThreads ctx) simple threads' $ if success then PutVar cvid woken else BlockedPutVar cvid - -- | Try to put a value into a @MVar@, without blocking. - stepTryPutVar cvar@(MVar cvid _) a c = synchronised $ do - (success, threads', woken) <- tryPutIntoMVar cvar a c tid threads + -- try to put a value into a @MVar@, without blocking. + ATryPutVar cvar@(MVar cvid _) a c -> synchronised $ do + (success, threads', woken) <- tryPutIntoMVar cvar a c tid (cThreads ctx) simple threads' $ TryPutVar cvid success woken - -- | Get the value from a @MVar@, without emptying, blocking the + -- get the value from a @MVar@, without emptying, blocking the -- thread until it's full. - stepReadVar cvar@(MVar cvid _) c = synchronised $ do - (success, threads', _) <- readFromMVar cvar c tid threads + AReadVar cvar@(MVar cvid _) c -> synchronised $ do + (success, threads', _) <- readFromMVar cvar c tid (cThreads ctx) simple threads' $ if success then ReadVar cvid else BlockedReadVar cvid - -- | Take the value from a @MVar@, blocking the thread until it's + -- take the value from a @MVar@, blocking the thread until it's -- full. - stepTakeVar cvar@(MVar cvid _) c = synchronised $ do - (success, threads', woken) <- takeFromMVar cvar c tid threads + ATakeVar cvar@(MVar cvid _) c -> synchronised $ do + (success, threads', woken) <- takeFromMVar cvar c tid (cThreads ctx) simple threads' $ if success then TakeVar cvid woken else BlockedTakeVar cvid - -- | Try to take the value from a @MVar@, without blocking. - stepTryTakeVar cvar@(MVar cvid _) c = synchronised $ do - (success, threads', woken) <- tryTakeFromMVar cvar c tid threads + -- try to take the value from a @MVar@, without blocking. + ATryTakeVar cvar@(MVar cvid _) c -> synchronised $ do + (success, threads', woken) <- tryTakeFromMVar cvar c tid (cThreads ctx) simple threads' $ TryTakeVar cvid success woken - -- | Read from a @CRef@. - stepReadRef cref@(CRef crid _) c = do + -- create a new @CRef@, using the next 'CRefId'. + ANewRef n a c -> do + let (idSource', newcrid) = nextCRId n (cIdSource ctx) + ref <- newRef (M.empty, 0, a) + let cref = CRef newcrid ref + pure $ Right (ctx { cThreads = goto (c cref) tid (cThreads ctx), cIdSource = idSource' }, Right (NewRef newcrid)) + + -- read from a @CRef@. + AReadRef cref@(CRef crid _) c -> do val <- readCRef cref tid - simple (goto (c val) tid threads) $ ReadRef crid + simple (goto (c val) tid (cThreads ctx)) $ ReadRef crid - -- | Read from a @CRef@ for future compare-and-swap operations. - stepReadRefCas cref@(CRef crid _) c = do + -- read from a @CRef@ for future compare-and-swap operations. + AReadRefCas cref@(CRef crid _) c -> do tick <- readForTicket cref tid - simple (goto (c tick) tid threads) $ ReadRefCas crid + simple (goto (c tick) tid (cThreads ctx)) $ ReadRefCas crid - -- | Modify a @CRef@. - stepModRef cref@(CRef crid _) f c = synchronised $ do + -- modify a @CRef@. + AModRef cref@(CRef crid _) f c -> synchronised $ do (new, val) <- f <$> readCRef cref tid writeImmediate cref new - simple (goto (c val) tid threads) $ ModRef crid + simple (goto (c val) tid (cThreads ctx)) $ ModRef crid - -- | Modify a @CRef@ using a compare-and-swap. - stepModRefCas cref@(CRef crid _) f c = synchronised $ do + -- modify a @CRef@ using a compare-and-swap. + AModRefCas cref@(CRef crid _) f c -> synchronised $ do tick@(Ticket _ _ old) <- readForTicket cref tid let (new, val) = f old void $ casCRef cref tid tick new - simple (goto (c val) tid threads) $ ModRefCas crid + simple (goto (c val) tid (cThreads ctx)) $ ModRefCas crid - -- | Write to a @CRef@ without synchronising - stepWriteRef cref@(CRef crid _) a c = case memtype of - -- Write immediately. + -- write to a @CRef@ without synchronising. + AWriteRef cref@(CRef crid _) a c -> case memtype of + -- write immediately. SequentialConsistency -> do writeImmediate cref a - simple (goto c tid threads) $ WriteRef crid - - -- Add to buffer using thread id. + simple (goto c tid (cThreads ctx)) $ WriteRef crid + -- add to buffer using thread id. TotalStoreOrder -> do - wb' <- bufferWrite wb (tid, Nothing) cref a - return $ Right (goto c tid threads, idSource, WriteRef crid, wb', caps) - - -- Add to buffer using both thread id and cref id + wb' <- bufferWrite (cWriteBuf ctx) (tid, Nothing) cref a + pure $ Right (ctx { cThreads = goto c tid (cThreads ctx), cWriteBuf = wb' }, Right (WriteRef crid)) + -- add to buffer using both thread id and cref id PartialStoreOrder -> do - wb' <- bufferWrite wb (tid, Just crid) cref a - return $ Right (goto c tid threads, idSource, WriteRef crid, wb', caps) + wb' <- bufferWrite (cWriteBuf ctx) (tid, Just crid) cref a + pure $ Right (ctx { cThreads = goto c tid (cThreads ctx), cWriteBuf = wb' }, Right (WriteRef crid)) - -- | Perform a compare-and-swap on a @CRef@. - stepCasRef cref@(CRef crid _) tick a c = synchronised $ do + -- perform a compare-and-swap on a @CRef@. + ACasRef cref@(CRef crid _) tick a c -> synchronised $ do (suc, tick') <- casCRef cref tid tick a - simple (goto (c (suc, tick')) tid threads) $ CasRef crid suc + simple (goto (c (suc, tick')) tid (cThreads ctx)) $ CasRef crid suc - -- | Commit a @CRef@ write - stepCommit t c = do + -- commit a @CRef@ write + ACommit t c -> do wb' <- case memtype of - -- Shouldn't ever get here + -- shouldn't ever get here SequentialConsistency -> error "Attempting to commit under SequentialConsistency" + -- commit using the thread id. + TotalStoreOrder -> commitWrite (cWriteBuf ctx) (t, Nothing) + -- commit using the cref id. + PartialStoreOrder -> commitWrite (cWriteBuf ctx) (t, Just c) + pure $ Right (ctx { cWriteBuf = wb' }, Right (CommitRef t c)) - -- Commit using the thread id. - TotalStoreOrder -> commitWrite wb (t, Nothing) - - -- Commit using the cref id. - PartialStoreOrder -> commitWrite wb (t, Just c) - - return $ Right (threads, idSource, CommitRef t c, wb', caps) - - -- | Run a STM transaction atomically. - stepAtom stm c = synchronised $ do - (res, idSource', trace) <- runstm stm idSource + -- run a STM transaction atomically. + AAtom stm c -> synchronised $ do + (res, idSource', trace) <- runTransaction stm (cIdSource ctx) case res of Success _ written val -> - let (threads', woken) = wake (OnTVar written) threads - in return $ Right (goto (c val) tid threads', idSource', STM trace woken, wb, caps) + let (threads', woken) = wake (OnTVar written) (cThreads ctx) + in pure $ Right (ctx { cThreads = goto (c val) tid threads', cIdSource = idSource' }, Right (STM trace woken)) Retry touched -> - let threads' = block (OnTVar touched) tid threads - in return $ Right (threads', idSource', BlockedSTM trace, wb, caps) + let threads' = block (OnTVar touched) tid (cThreads ctx) + in pure $ Right (ctx { cThreads = threads', cIdSource = idSource'}, Right (BlockedSTM trace)) Exception e -> do res' <- stepThrow e - return $ case res' of - Right (threads', _, _, _, _) -> Right (threads', idSource', Throw, wb, caps) + pure $ case res' of + Right (ctx', _) -> Right (ctx' { cIdSource = idSource' }, Right Throw) Left err -> Left err - -- | Run a subcomputation in an exception-catching context. - stepCatching h ma c = simple threads' Catching where - a = runCont ma (APopCatching . c) - e exc = runCont (h exc) (APopCatching . c) + -- lift an action from the underlying monad into the @Conc@ + -- computation. + ALift na -> do + a <- na + simple (goto a tid (cThreads ctx)) LiftIO - threads' = goto a tid (catching e tid threads) - - -- | Pop the top exception handler from the thread's stack. - stepPopCatching a = simple threads' PopCatching where - threads' = goto a tid (uncatching tid threads) - - -- | Throw an exception, and propagate it to the appropriate + -- throw an exception, and propagate it to the appropriate -- handler. - stepThrow e = - case propagate (toException e) tid threads of - Just threads' -> simple threads' Throw - Nothing -> return $ Left UncaughtException + AThrow e -> stepThrow e - -- | Throw an exception to the target thread, and propagate it to + -- throw an exception to the target thread, and propagate it to -- the appropriate handler. - stepThrowTo t e c = synchronised $ - let threads' = goto c tid threads - blocked = block (OnMask t) tid threads - in case M.lookup t threads of + AThrowTo t e c -> synchronised $ + let threads' = goto c tid (cThreads ctx) + blocked = block (OnMask t) tid (cThreads ctx) + in case M.lookup t (cThreads ctx) of Just thread | interruptible thread -> case propagate (toException e) t threads' of Just threads'' -> simple threads'' $ ThrowTo t Nothing - | t == initialThread -> return $ Left UncaughtException + | t == initialThread -> pure $ Left UncaughtException | otherwise -> simple (kill t threads') $ ThrowTo t | otherwise -> simple blocked $ BlockedThrowTo t Nothing -> simple threads' $ ThrowTo t - -- | Execute a subcomputation with a new masking state, and give - -- it a function to run a computation with the current masking - -- state. - -- - -- Explicit type sig necessary for checking in the prescence of - -- 'umask', sadly. - stepMasking :: MaskingState - -> ((forall b. M n r s b -> M n r s b) -> M n r s a) - -> (a -> Action n r s) - -> n (Either Failure (Threads n r s, IdSource, ThreadAction, WriteBuffer r, Int)) - stepMasking m ma c = simple threads' $ SetMasking False m where - a = runCont (ma umask) (AResetMask False False m' . c) + -- run a subcomputation in an exception-catching context. + ACatching h ma c -> + let a = runCont ma (APopCatching . c) + e exc = runCont (h exc) (APopCatching . c) + threads' = goto a tid (catching e tid (cThreads ctx)) + in simple threads' Catching - m' = _masking . fromJust $ M.lookup tid threads - umask mb = resetMask True m' >> mb >>= \b -> resetMask False m >> return b - resetMask typ ms = cont $ \k -> AResetMask typ True ms $ k () + -- pop the top exception handler from the thread's stack. + APopCatching a -> + let threads' = goto a tid (uncatching tid (cThreads ctx)) + in simple threads' PopCatching - threads' = goto a tid (mask m tid threads) + -- execute a subcomputation with a new masking state, and give it + -- a function to run a computation with the current masking state. + AMasking m ma c -> + let a = runCont (ma umask) (AResetMask False False m' . c) + m' = _masking . fromJust $ M.lookup tid (cThreads ctx) + umask mb = resetMask True m' >> mb >>= \b -> resetMask False m >> pure b + resetMask typ ms = cont $ \k -> AResetMask typ True ms $ k () + threads' = goto a tid (mask m tid (cThreads ctx)) + in simple threads' $ SetMasking False m - -- | Reset the masking thread of the state. - stepResetMask b1 b2 m c = simple threads' act where - act = (if b1 then SetMasking else ResetMasking) b2 m - threads' = goto c tid (mask m tid threads) - -- | Create a new @MVar@, using the next 'MVarId'. - stepNewVar n c = do - let (idSource', newmvid) = nextMVId n idSource - ref <- newRef Nothing - let mvar = MVar newmvid ref - return $ Right (goto (c mvar) tid threads, idSource', NewVar newmvid, wb, caps) + -- reset the masking thread of the state. + AResetMask b1 b2 m c -> + let act = (if b1 then SetMasking else ResetMasking) b2 m + threads' = goto c tid (mask m tid (cThreads ctx)) + in simple threads' act - -- | Create a new @CRef@, using the next 'CRefId'. - stepNewRef n a c = do - let (idSource', newcrid) = nextCRId n idSource - ref <- newRef (M.empty, 0, a) - let cref = CRef newcrid ref - return $ Right (goto (c cref) tid threads, idSource', NewRef newcrid, wb, caps) + -- execute a 'return' or 'pure'. + AReturn c -> simple (goto c tid (cThreads ctx)) Return - -- | Lift an action from the underlying monad into the @Conc@ - -- computation. - stepLift na = do - a <- na - simple (goto a tid threads) LiftIO + -- add a message to the trace. + AMessage m c -> simple (goto c tid (cThreads ctx)) (Message m) - -- | Execute a 'return' or 'pure'. - stepReturn c = simple (goto c tid threads) Return + -- kill the current thread. + AStop na -> na >> simple (kill tid (cThreads ctx)) Stop - -- | Add a message to the trace. - stepMessage m c = simple (goto c tid threads) (Message m) + -- run a subconcurrent computation. + ASub ma c + | M.size (cThreads ctx) > 1 -> pure (Left IllegalSubconcurrency) + | otherwise -> do + (res, g', trace) <- runConcurrency sched memtype (cSchedState ctx) ma + pure $ Right (ctx { cThreads = goto (c res) tid (cThreads ctx), cSchedState = g' }, Left (Subconcurrency, trace)) + where - -- | Kill the current thread. - stepStop na = na >> simple (kill tid threads) Stop + -- this is not inline in the long @case@ above as it's needed by + -- @AAtom@, @AThrow@, and @AThrowTo@. + stepThrow e = + case propagate (toException e) tid (cThreads ctx) of + Just threads' -> simple threads' Throw + Nothing -> pure $ Left UncaughtException - -- | Helper for actions which don't touch the 'IdSource' or - -- 'WriteBuffer' - simple threads' act = return $ Right (threads', idSource, act, wb, caps) + -- helper for actions which only change the threads. + simple threads' act = pure $ Right (ctx { cThreads = threads' }, Right act) - -- | Helper for actions impose a write barrier. + -- helper for actions impose a write barrier. synchronised ma = do - writeBarrier wb + writeBarrier (cWriteBuf ctx) res <- ma return $ case res of - Right (threads', idSource', act', _, caps') -> Right (threads', idSource', act', emptyBuffer, caps') + Right (ctx', act) -> Right (ctx' { cWriteBuf = emptyBuffer }, act) _ -> res diff --git a/dejafu/Test/DejaFu/Conc/Internal/Common.hs b/dejafu/Test/DejaFu/Conc/Internal/Common.hs index b69b62d..aab7eae 100755 --- a/dejafu/Test/DejaFu/Conc/Internal/Common.hs +++ b/dejafu/Test/DejaFu/Conc/Internal/Common.hs @@ -18,6 +18,7 @@ import Data.Dynamic (Dynamic) import Data.Map.Strict (Map) import Data.List.NonEmpty (NonEmpty, fromList) import Test.DejaFu.Common +import Test.DejaFu.STM (STMLike) {-# ANN module ("HLint: ignore Use record patterns" :: String) #-} @@ -32,16 +33,16 @@ import Test.DejaFu.Common -- current expression of threads and exception handlers very difficult -- (perhaps even not possible without significant reworking), so I -- abandoned the attempt. -newtype M n r s a = M { runM :: (a -> Action n r s) -> Action n r s } +newtype M n r a = M { runM :: (a -> Action n r) -> Action n r } -instance Functor (M n r s) where +instance Functor (M n r) where fmap f m = M $ \ c -> runM m (c . f) -instance Applicative (M n r s) where +instance Applicative (M n r) where pure x = M $ \c -> AReturn $ c x f <*> v = M $ \c -> runM f (\g -> runM v (c . g)) -instance Monad (M n r s) where +instance Monad (M n r) where return = pure m >>= k = M $ \c -> runM m (\x -> runM (k x) c) @@ -82,11 +83,11 @@ data Ticket a = Ticket } -- | Construct a continuation-passing operation from a function. -cont :: ((a -> Action n r s) -> Action n r s) -> M n r s a +cont :: ((a -> Action n r) -> Action n r) -> M n r a cont = M -- | Run a CPS computation with the given final computation. -runCont :: M n r s a -> (a -> Action n r s) -> Action n r s +runCont :: M n r a -> (a -> Action n r) -> Action n r runCont = runM -------------------------------------------------------------------------------- @@ -96,49 +97,51 @@ runCont = runM -- only occur as a result of an action, and they cover (most of) the -- primitives of the concurrency. 'spawn' is absent as it is -- implemented in terms of 'newEmptyMVar', 'fork', and 'putMVar'. -data Action n r s = - AFork String ((forall b. M n r s b -> M n r s b) -> Action n r s) (ThreadId -> Action n r s) - | AMyTId (ThreadId -> Action n r s) +data Action n r = + AFork String ((forall b. M n r b -> M n r b) -> Action n r) (ThreadId -> Action n r) + | AMyTId (ThreadId -> Action n r) - | AGetNumCapabilities (Int -> Action n r s) - | ASetNumCapabilities Int (Action n r s) + | AGetNumCapabilities (Int -> Action n r) + | ASetNumCapabilities Int (Action n r) - | forall a. ANewVar String (MVar r a -> Action n r s) - | forall a. APutVar (MVar r a) a (Action n r s) - | forall a. ATryPutVar (MVar r a) a (Bool -> Action n r s) - | forall a. AReadVar (MVar r a) (a -> Action n r s) - | forall a. ATakeVar (MVar r a) (a -> Action n r s) - | forall a. ATryTakeVar (MVar r a) (Maybe a -> Action n r s) + | forall a. ANewVar String (MVar r a -> Action n r) + | forall a. APutVar (MVar r a) a (Action n r) + | forall a. ATryPutVar (MVar r a) a (Bool -> Action n r) + | forall a. AReadVar (MVar r a) (a -> Action n r) + | forall a. ATakeVar (MVar r a) (a -> Action n r) + | forall a. ATryTakeVar (MVar r a) (Maybe a -> Action n r) - | forall a. ANewRef String a (CRef r a -> Action n r s) - | forall a. AReadRef (CRef r a) (a -> Action n r s) - | forall a. AReadRefCas (CRef r a) (Ticket a -> Action n r s) - | forall a b. AModRef (CRef r a) (a -> (a, b)) (b -> Action n r s) - | forall a b. AModRefCas (CRef r a) (a -> (a, b)) (b -> Action n r s) - | forall a. AWriteRef (CRef r a) a (Action n r s) - | forall a. ACasRef (CRef r a) (Ticket a) a ((Bool, Ticket a) -> Action n r s) + | forall a. ANewRef String a (CRef r a -> Action n r) + | forall a. AReadRef (CRef r a) (a -> Action n r) + | forall a. AReadRefCas (CRef r a) (Ticket a -> Action n r) + | forall a b. AModRef (CRef r a) (a -> (a, b)) (b -> Action n r) + | forall a b. AModRefCas (CRef r a) (a -> (a, b)) (b -> Action n r) + | forall a. AWriteRef (CRef r a) a (Action n r) + | forall a. ACasRef (CRef r a) (Ticket a) a ((Bool, Ticket a) -> Action n r) | forall e. Exception e => AThrow e - | forall e. Exception e => AThrowTo ThreadId e (Action n r s) - | forall a e. Exception e => ACatching (e -> M n r s a) (M n r s a) (a -> Action n r s) - | APopCatching (Action n r s) - | forall a. AMasking MaskingState ((forall b. M n r s b -> M n r s b) -> M n r s a) (a -> Action n r s) - | AResetMask Bool Bool MaskingState (Action n r s) + | forall e. Exception e => AThrowTo ThreadId e (Action n r) + | forall a e. Exception e => ACatching (e -> M n r a) (M n r a) (a -> Action n r) + | APopCatching (Action n r) + | forall a. AMasking MaskingState ((forall b. M n r b -> M n r b) -> M n r a) (a -> Action n r) + | AResetMask Bool Bool MaskingState (Action n r) - | AMessage Dynamic (Action n r s) + | AMessage Dynamic (Action n r) - | forall a. AAtom (s a) (a -> Action n r s) - | ALift (n (Action n r s)) - | AYield (Action n r s) - | AReturn (Action n r s) + | forall a. AAtom (STMLike n r a) (a -> Action n r) + | ALift (n (Action n r)) + | AYield (Action n r) + | AReturn (Action n r) | ACommit ThreadId CRefId | AStop (n ()) + | forall a. ASub (M n r a) (Either Failure a -> Action n r) + -------------------------------------------------------------------------------- -- * Scheduling & Traces -- | Look as far ahead in the given continuation as possible. -lookahead :: Action n r s -> NonEmpty Lookahead +lookahead :: Action n r -> NonEmpty Lookahead lookahead = fromList . lookahead' where lookahead' (AFork _ _ _) = [WillFork] lookahead' (AMyTId _) = [WillMyThreadId] @@ -170,3 +173,4 @@ lookahead = fromList . lookahead' where lookahead' (AYield k) = WillYield : lookahead' k lookahead' (AReturn k) = WillReturn : lookahead' k lookahead' (AStop _) = [WillStop] + lookahead' (ASub _ _) = [WillSubconcurrency] diff --git a/dejafu/Test/DejaFu/Conc/Internal/Memory.hs b/dejafu/Test/DejaFu/Conc/Internal/Memory.hs index 2165cad..a63a763 100755 --- a/dejafu/Test/DejaFu/Conc/Internal/Memory.hs +++ b/dejafu/Test/DejaFu/Conc/Internal/Memory.hs @@ -126,7 +126,7 @@ writeBarrier (WriteBuffer wb) = mapM_ flush $ M.elems wb where flush = mapM_ $ \(BufferedWrite _ cref a) -> writeImmediate cref a -- | Add phantom threads to the thread list to commit pending writes. -addCommitThreads :: WriteBuffer r -> Threads n r s -> Threads n r s +addCommitThreads :: WriteBuffer r -> Threads n r -> Threads n r addCommitThreads (WriteBuffer wb) ts = ts <> M.fromList phantoms where phantoms = [ (ThreadId Nothing $ negate tid, mkthread $ fromJust c) | ((k, b), tid) <- zip (M.toList wb) [1..] @@ -136,41 +136,41 @@ addCommitThreads (WriteBuffer wb) ts = ts <> M.fromList phantoms where go EmptyL = Nothing -- | Remove phantom threads. -delCommitThreads :: Threads n r s -> Threads n r s +delCommitThreads :: Threads n r -> Threads n r delCommitThreads = M.filterWithKey $ \k _ -> k >= initialThread -------------------------------------------------------------------------------- -- * Manipulating @MVar@s -- | Put into a @MVar@, blocking if full. -putIntoMVar :: MonadRef r n => MVar r a -> a -> Action n r s - -> ThreadId -> Threads n r s -> n (Bool, Threads n r s, [ThreadId]) +putIntoMVar :: MonadRef r n => MVar r a -> a -> Action n r + -> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId]) putIntoMVar cvar a c = mutMVar True cvar a (const c) -- | Try to put into a @MVar@, not blocking if full. -tryPutIntoMVar :: MonadRef r n => MVar r a -> a -> (Bool -> Action n r s) - -> ThreadId -> Threads n r s -> n (Bool, Threads n r s, [ThreadId]) +tryPutIntoMVar :: MonadRef r n => MVar r a -> a -> (Bool -> Action n r) + -> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId]) tryPutIntoMVar = mutMVar False -- | Read from a @MVar@, blocking if empty. -readFromMVar :: MonadRef r n => MVar r a -> (a -> Action n r s) - -> ThreadId -> Threads n r s -> n (Bool, Threads n r s, [ThreadId]) +readFromMVar :: MonadRef r n => MVar r a -> (a -> Action n r) + -> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId]) readFromMVar cvar c = seeMVar False True cvar (c . fromJust) -- | Take from a @MVar@, blocking if empty. -takeFromMVar :: MonadRef r n => MVar r a -> (a -> Action n r s) - -> ThreadId -> Threads n r s -> n (Bool, Threads n r s, [ThreadId]) +takeFromMVar :: MonadRef r n => MVar r a -> (a -> Action n r) + -> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId]) takeFromMVar cvar c = seeMVar True True cvar (c . fromJust) -- | Try to take from a @MVar@, not blocking if empty. -tryTakeFromMVar :: MonadRef r n => MVar r a -> (Maybe a -> Action n r s) - -> ThreadId -> Threads n r s -> n (Bool, Threads n r s, [ThreadId]) +tryTakeFromMVar :: MonadRef r n => MVar r a -> (Maybe a -> Action n r) + -> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId]) tryTakeFromMVar = seeMVar True False -- | Mutate a @MVar@, in either a blocking or nonblocking way. mutMVar :: MonadRef r n - => Bool -> MVar r a -> a -> (Bool -> Action n r s) - -> ThreadId -> Threads n r s -> n (Bool, Threads n r s, [ThreadId]) + => Bool -> MVar r a -> a -> (Bool -> Action n r) + -> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId]) mutMVar blocking (MVar cvid ref) a c threadid threads = do val <- readRef ref @@ -191,8 +191,8 @@ mutMVar blocking (MVar cvid ref) a c threadid threads = do -- | Read a @MVar@, in either a blocking or nonblocking -- way. seeMVar :: MonadRef r n - => Bool -> Bool -> MVar r a -> (Maybe a -> Action n r s) - -> ThreadId -> Threads n r s -> n (Bool, Threads n r s, [ThreadId]) + => Bool -> Bool -> MVar r a -> (Maybe a -> Action n r) + -> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId]) seeMVar emptying blocking (MVar cvid ref) c threadid threads = do val <- readRef ref diff --git a/dejafu/Test/DejaFu/Conc/Internal/Threading.hs b/dejafu/Test/DejaFu/Conc/Internal/Threading.hs index 63925d6..2302a67 100644 --- a/dejafu/Test/DejaFu/Conc/Internal/Threading.hs +++ b/dejafu/Test/DejaFu/Conc/Internal/Threading.hs @@ -27,22 +27,22 @@ import qualified Data.Map.Strict as M -- * Threads -- | Threads are stored in a map index by 'ThreadId'. -type Threads n r s = Map ThreadId (Thread n r s) +type Threads n r = Map ThreadId (Thread n r) -- | All the state of a thread. -data Thread n r s = Thread - { _continuation :: Action n r s +data Thread n r = Thread + { _continuation :: Action n r -- ^ The next action to execute. , _blocking :: Maybe BlockedOn -- ^ The state of any blocks. - , _handlers :: [Handler n r s] + , _handlers :: [Handler n r] -- ^ Stack of exception handlers , _masking :: MaskingState -- ^ The exception masking state. } -- | Construct a thread with just one action -mkthread :: Action n r s -> Thread n r s +mkthread :: Action n r -> Thread n r mkthread c = Thread c Nothing [] Unmasked -------------------------------------------------------------------------------- @@ -53,7 +53,7 @@ mkthread c = Thread c Nothing [] Unmasked data BlockedOn = OnMVarFull MVarId | OnMVarEmpty MVarId | OnTVar [TVarId] | OnMask ThreadId deriving Eq -- | Determine if a thread is blocked in a certain way. -(~=) :: Thread n r s -> BlockedOn -> Bool +(~=) :: Thread n r -> BlockedOn -> Bool thread ~= theblock = case (_blocking thread, theblock) of (Just (OnMVarFull _), OnMVarFull _) -> True (Just (OnMVarEmpty _), OnMVarEmpty _) -> True @@ -65,11 +65,11 @@ thread ~= theblock = case (_blocking thread, theblock) of -- * Exceptions -- | An exception handler. -data Handler n r s = forall e. Exception e => Handler (e -> Action n r s) +data Handler n r = forall e. Exception e => Handler (e -> Action n r) -- | Propagate an exception upwards, finding the closest handler -- which can deal with it. -propagate :: SomeException -> ThreadId -> Threads n r s -> Maybe (Threads n r s) +propagate :: SomeException -> ThreadId -> Threads n r -> Maybe (Threads n r) propagate e tid threads = case M.lookup tid threads >>= go . _handlers of Just (act, hs) -> Just $ except act hs tid threads Nothing -> Nothing @@ -79,40 +79,40 @@ propagate e tid threads = case M.lookup tid threads >>= go . _handlers of go (Handler h:hs) = maybe (go hs) (\act -> Just (act, hs)) $ h <$> fromException e -- | Check if a thread can be interrupted by an exception. -interruptible :: Thread n r s -> Bool +interruptible :: Thread n r -> Bool interruptible thread = _masking thread == Unmasked || (_masking thread == MaskedInterruptible && isJust (_blocking thread)) -- | Register a new exception handler. -catching :: Exception e => (e -> Action n r s) -> ThreadId -> Threads n r s -> Threads n r s +catching :: Exception e => (e -> Action n r) -> ThreadId -> Threads n r -> Threads n r catching h = M.alter $ \(Just thread) -> Just $ thread { _handlers = Handler h : _handlers thread } -- | Remove the most recent exception handler. -uncatching :: ThreadId -> Threads n r s -> Threads n r s +uncatching :: ThreadId -> Threads n r -> Threads n r uncatching = M.alter $ \(Just thread) -> Just $ thread { _handlers = tail $ _handlers thread } -- | Raise an exception in a thread. -except :: Action n r s -> [Handler n r s] -> ThreadId -> Threads n r s -> Threads n r s +except :: Action n r -> [Handler n r] -> ThreadId -> Threads n r -> Threads n r except act hs = M.alter $ \(Just thread) -> Just $ thread { _continuation = act, _handlers = hs, _blocking = Nothing } -- | Set the masking state of a thread. -mask :: MaskingState -> ThreadId -> Threads n r s -> Threads n r s +mask :: MaskingState -> ThreadId -> Threads n r -> Threads n r mask ms = M.alter $ \(Just thread) -> Just $ thread { _masking = ms } -------------------------------------------------------------------------------- -- * Manipulating threads -- | Replace the @Action@ of a thread. -goto :: Action n r s -> ThreadId -> Threads n r s -> Threads n r s +goto :: Action n r -> ThreadId -> Threads n r -> Threads n r goto a = M.alter $ \(Just thread) -> Just (thread { _continuation = a }) -- | Start a thread with the given ID, inheriting the masking state -- from the parent thread. This ID must not already be in use! -launch :: ThreadId -> ThreadId -> ((forall b. M n r s b -> M n r s b) -> Action n r s) -> Threads n r s -> Threads n r s +launch :: ThreadId -> ThreadId -> ((forall b. M n r b -> M n r b) -> Action n r) -> Threads n r -> Threads n r launch parent tid a threads = launch' ms tid a threads where ms = fromMaybe Unmasked $ _masking <$> M.lookup parent threads -- | Start a thread with the given ID and masking state. This must not already be in use! -launch' :: MaskingState -> ThreadId -> ((forall b. M n r s b -> M n r s b) -> Action n r s) -> Threads n r s -> Threads n r s +launch' :: MaskingState -> ThreadId -> ((forall b. M n r b -> M n r b) -> Action n r) -> Threads n r -> Threads n r launch' ms tid a = M.insert tid thread where thread = Thread { _continuation = a umask, _blocking = Nothing, _handlers = [], _masking = ms } @@ -120,11 +120,11 @@ launch' ms tid a = M.insert tid thread where resetMask typ m = cont $ \k -> AResetMask typ True m $ k () -- | Kill a thread. -kill :: ThreadId -> Threads n r s -> Threads n r s +kill :: ThreadId -> Threads n r -> Threads n r kill = M.delete -- | Block a thread. -block :: BlockedOn -> ThreadId -> Threads n r s -> Threads n r s +block :: BlockedOn -> ThreadId -> Threads n r -> Threads n r block blockedOn = M.alter doBlock where doBlock (Just thread) = Just $ thread { _blocking = Just blockedOn } doBlock _ = error "Invariant failure in 'block': thread does NOT exist!" @@ -132,7 +132,7 @@ block blockedOn = M.alter doBlock where -- | Unblock all threads waiting on the appropriate block. For 'TVar' -- blocks, this will wake all threads waiting on at least one of the -- given 'TVar's. -wake :: BlockedOn -> Threads n r s -> (Threads n r s, [ThreadId]) +wake :: BlockedOn -> Threads n r -> (Threads n r, [ThreadId]) wake blockedOn threads = (unblock <$> threads, M.keys $ M.filter isBlocked threads) where unblock thread | isBlocked thread = thread { _blocking = Nothing } diff --git a/dejafu/Test/DejaFu/SCT.hs b/dejafu/Test/DejaFu/SCT.hs index dfdd3e7..b3bd510 100755 --- a/dejafu/Test/DejaFu/SCT.hs +++ b/dejafu/Test/DejaFu/SCT.hs @@ -79,9 +79,8 @@ module Test.DejaFu.SCT , sctLengthBound ) where -import Control.DeepSeq (NFData(..)) import Control.Monad.Ref (MonadRef) -import Data.List (nub) +import Data.List (foldl') import qualified Data.Map.Strict as M import Data.Maybe (isJust, fromJust) import qualified Data.Set as S @@ -140,17 +139,16 @@ cBound (Bounds pb fb lb) = -- -- If no bounds are enabled, just backtrack to the given point. cBacktrack :: Bounds -> BacktrackFunc -cBacktrack (Bounds Nothing Nothing Nothing) bs i t = backtrackAt (const False) False bs i t -cBacktrack (Bounds pb fb lb) bs i t = lBack . fBack $ pBack bs where - pBack backs = if isJust pb then pBacktrack backs i t else backs - fBack backs = if isJust fb then fBacktrack backs i t else backs - lBack backs = if isJust lb then lBacktrack backs i t else backs +cBacktrack (Bounds (Just _) _ _) = pBacktrack +cBacktrack (Bounds _ (Just _) _) = fBacktrack +cBacktrack (Bounds _ _ (Just _)) = lBacktrack +cBacktrack _ = backtrackAt (\_ _ -> False) ------------------------------------------------------------------------------- -- Pre-emption bounding newtype PreemptionBound = PreemptionBound Int - deriving (NFData, Enum, Eq, Ord, Num, Real, Integral, Read, Show) + deriving (Enum, Eq, Ord, Num, Real, Integral, Read, Show) -- | A sensible default preemption bound: 2. -- @@ -181,31 +179,26 @@ pBound (PreemptionBound pb) ts dl = preEmpCount ts dl <= pb -- the same state being reached multiple times, but is needed because -- of the artificial dependency imposed by the bound. pBacktrack :: BacktrackFunc -pBacktrack bs i tid = - maybe id (\j' b -> backtrack True b j' tid) j $ backtrack False bs i tid +pBacktrack bs = backtrackAt (\_ _ -> False) bs . concatMap addConservative where + addConservative o@(i, _, tid) = o : case conservative i of + Just j -> [(j, True, tid)] + Nothing -> [] - where - -- Index of the conservative point - j = goJ . reverse . pairs $ zip [0..i-1] bs where - goJ (((_,b1), (j',b2)):rest) - | bcktThreadid b1 /= bcktThreadid b2 - && not (isCommitRef . snd $ bcktDecision b1) - && not (isCommitRef . snd $ bcktDecision b2) = Just j' - | otherwise = goJ rest - goJ [] = Nothing - - -- List of adjacent pairs - {-# INLINE pairs #-} - pairs = zip <*> tail - - -- Add a backtracking point. - backtrack = backtrackAt $ const False + -- index of conservative point + conservative i = go (reverse (take (i-1) bs)) (i-1) where + go _ (-1) = Nothing + go (b1:rest@(b2:_)) j + | bcktThreadid b1 /= bcktThreadid b2 + && not (isCommitRef $ bcktAction b1) + && not (isCommitRef $ bcktAction b2) = Just j + | otherwise = go rest (j-1) + go _ _ = Nothing ------------------------------------------------------------------------------- -- Fair bounding newtype FairBound = FairBound Int - deriving (NFData, Enum, Eq, Ord, Num, Real, Integral, Read, Show) + deriving (Enum, Eq, Ord, Num, Real, Integral, Read, Show) -- | A sensible default fair bound: 5. -- @@ -233,15 +226,15 @@ fBound (FairBound fb) ts (_, l) = maxYieldCountDiff ts l <= fb -- | Add a backtrack point. If the thread isn't runnable, or performs -- a release operation, add all runnable threads. fBacktrack :: BacktrackFunc -fBacktrack bs i t = backtrackAt check False bs i t where +fBacktrack = backtrackAt check where -- True if a release operation is performed. - check b = Just True == (willRelease <$> M.lookup t (bcktRunnable b)) + check t b = Just True == (willRelease <$> M.lookup t (bcktRunnable b)) ------------------------------------------------------------------------------- -- Length bounding newtype LengthBound = LengthBound Int - deriving (NFData, Enum, Eq, Ord, Num, Real, Integral, Read, Show) + deriving (Enum, Eq, Ord, Num, Real, Integral, Read, Show) -- | A sensible default length bound: 250. -- @@ -269,7 +262,7 @@ lBound (LengthBound lb) ts _ = length ts < lb -- | Add a backtrack point. If the thread isn't runnable, add all -- runnable threads. lBacktrack :: BacktrackFunc -lBacktrack = backtrackAt (const False) False +lBacktrack = backtrackAt (\_ _ -> False) ------------------------------------------------------------------------------- -- DPOR @@ -313,7 +306,7 @@ sctBounded memtype bf backtrack conc = go initialState where if schedIgnore s then go newDPOR - else ((res, trace):) <$> go (pruneCommits $ addBacktracks bpoints newDPOR) + else ((res, trace):) <$> go (addBacktracks bpoints newDPOR) Nothing -> pure [] @@ -332,32 +325,11 @@ sctBounded memtype bf backtrack conc = go initialState where -- Incorporate the new backtracking steps into the DPOR tree. addBacktracks = incorporateBacktrackSteps bf -------------------------------------------------------------------------------- --- Post-processing - --- | Remove commits from the todo sets where every other action will --- result in a write barrier (and so a commit) occurring. --- --- To get the benefit from this, do not execute commit actions from --- the todo set until there are no other choises. -pruneCommits :: DPOR -> DPOR -pruneCommits bpor - | not onlycommits || not alldonesync = go bpor - | otherwise = go bpor { dporTodo = M.empty } - - where - go b = b { dporDone = pruneCommits <$> dporDone bpor } - - onlycommits = all ( DepState -> (ThreadId, ThreadAction) -> (ThreadId, ThreadAction) -> Bool +dependent :: MemType -> DepState -> ThreadId -> ThreadAction -> ThreadId -> ThreadAction -> Bool -- This is basically the same as 'dependent'', but can make use of the -- additional information in a 'ThreadAction' to make different -- decisions in a few cases: @@ -381,14 +353,14 @@ dependent :: MemType -> DepState -> (ThreadId, ThreadAction) -> (ThreadId, Threa -- - Dependency of STM transactions can be /greatly/ improved here, -- as the 'Lookahead' does not know which @TVar@s will be touched, -- and so has to assume all transactions are dependent. -dependent _ _ (_, SetNumCapabilities a) (_, GetNumCapabilities b) = a /= b -dependent _ ds (_, ThrowTo t) (t2, a) = t == t2 && canInterrupt ds t2 a -dependent memtype ds (t1, a1) (t2, a2) = case rewind a2 of +dependent _ _ _ (SetNumCapabilities a) _ (GetNumCapabilities b) = a /= b +dependent _ ds _ (ThrowTo t) t2 a = t == t2 && canInterrupt ds t2 a +dependent memtype ds t1 a1 t2 a2 = case rewind a2 of Just l2 | isSTM a1 && isSTM a2 -> not . S.null $ tvarsOf a1 `S.intersection` tvarsOf a2 | not (isBlock a1 && isBarrier (simplifyLookahead l2)) -> - dependent' memtype ds (t1, a1) (t2, l2) + dependent' memtype ds t1 a1 t2 l2 _ -> dependentActions memtype ds (simplifyAction a1) (simplifyAction a2) where @@ -400,8 +372,8 @@ dependent memtype ds (t1, a1) (t2, a2) = case rewind a2 of -- -- Termination of the initial thread is handled specially in the DPOR -- implementation. -dependent' :: MemType -> DepState -> (ThreadId, ThreadAction) -> (ThreadId, Lookahead) -> Bool -dependent' memtype ds (t1, a1) (t2, l2) = case (a1, l2) of +dependent' :: MemType -> DepState -> ThreadId -> ThreadAction -> ThreadId -> Lookahead -> Bool +dependent' memtype ds t1 a1 t2 l2 = case (a1, l2) of -- Worst-case assumption: all IO is dependent. (LiftIO, WillLiftIO) -> True @@ -496,6 +468,7 @@ yieldCount tid ts l = go initialThread ts where | t == tid && willYield l = 1 | otherwise = 0 + {-# INLINE go' #-} go' t t' act rest | t == tid && didYield act = 1 + go t' rest | otherwise = go t' rest @@ -505,10 +478,14 @@ yieldCount tid ts l = go initialThread ts where maxYieldCountDiff :: [(Decision, ThreadAction)] -> Lookahead -> Int -maxYieldCountDiff ts l = maximum yieldCountDiffs where - yieldsBy tid = yieldCount tid ts l - yieldCounts = [yieldsBy tid | tid <- nub $ allTids ts] - yieldCountDiffs = [y1 - y2 | y1 <- yieldCounts, y2 <- yieldCounts] +maxYieldCountDiff ts l = go 0 yieldCounts where + go m (yc:ycs) = + let m' = m `max` foldl' (go' yc) 0 ycs + in go m' ycs + go m [] = m + go' yc0 m yc = m `max` abs (yc0 - yc) + + yieldCounts = [yieldCount t ts l | t <- allTids ts] -- All the threads created during the lifetime of the system. allTids ((_, act):rest) = diff --git a/dejafu/Test/DejaFu/SCT/Internal.hs b/dejafu/Test/DejaFu/SCT/Internal.hs index 657c020..3696cd1 100644 --- a/dejafu/Test/DejaFu/SCT/Internal.hs +++ b/dejafu/Test/DejaFu/SCT/Internal.hs @@ -11,14 +11,14 @@ -- interface of this library. module Test.DejaFu.SCT.Internal where -import Control.DeepSeq (NFData(..), force) import Control.Exception (MaskingState(..)) import Data.Char (ord) -import Data.List (foldl', intercalate, partition, sortBy) +import Data.Function (on) +import qualified Data.Foldable as F +import Data.List (intercalate, nubBy, partition, sortOn) import Data.List.NonEmpty (NonEmpty(..), toList) -import Data.Ord (Down(..), comparing) import Data.Map.Strict (Map) -import Data.Maybe (fromJust, isJust, isNothing, mapMaybe) +import Data.Maybe (catMaybes, fromJust, isNothing) import qualified Data.Map.Strict as M import Data.Set (Set) import qualified Data.Set as S @@ -51,16 +51,7 @@ data DPOR = DPOR , dporAction :: Maybe ThreadAction -- ^ What happened at this step. This will be 'Nothing' at the root, -- 'Just' everywhere else. - } - -instance NFData DPOR where - rnf dpor = rnf ( dporRunnable dpor - , dporTodo dpor - , dporDone dpor - , dporSleep dpor - , dporTaken dpor - , dporAction dpor - ) + } deriving Show -- | One step of the execution, including information for backtracking -- purposes. This backtracking information is used to generate new @@ -68,7 +59,9 @@ instance NFData DPOR where data BacktrackStep = BacktrackStep { bcktThreadid :: ThreadId -- ^ The thread running at this step - , bcktDecision :: (Decision, ThreadAction) + , bcktDecision :: Decision + -- ^ What was decided at this step. + , bcktAction :: ThreadAction -- ^ What happened at this step. , bcktRunnable :: Map ThreadId Lookahead -- ^ The threads runnable at this step @@ -77,15 +70,7 @@ data BacktrackStep = BacktrackStep -- alternatives were added conservatively due to the bound. , bcktState :: DepState -- ^ Some domain-specific state at this point. - } - -instance NFData BacktrackStep where - rnf b = rnf ( bcktThreadid b - , bcktDecision b - , bcktRunnable b - , bcktBacktracks b - , bcktState b - ) + } deriving Show -- | Initial DPOR state, given an initial thread ID. This initial -- thread should exist and be runnable at the start of execution. @@ -127,24 +112,17 @@ findSchedulePrefix predicate idx dpor0 (ts, c, slp) = allPrefixes !! i in Just (ts, c, slp, g) where - allPrefixes = go (initialDPORThread dpor0) dpor0 + allPrefixes = go dpor0 - go tid dpor = + go dpor = -- All the possible prefix traces from this point, with -- updated DPOR subtrees if taken from the done list. - let prefixes = concatMap go' (M.toList $ dporDone dpor) ++ here dpor - -- Sort by number of preemptions, in descending order. - cmp = Down . preEmps tid dpor . (\(a,_,_) -> a) - sorted = sortBy (comparing cmp) prefixes + let prefixes = here dpor : map go' (M.toList $ dporDone dpor) + in case concatPartition (\(t:_,_,_) -> predicate t) prefixes of + ([], choices) -> choices + (choices, _) -> choices - in if null prefixes - then [] - else case partition (\(t:_,_,_) -> predicate t) sorted of - ([], []) -> err "findSchedulePrefix" "empty prefix list!" - ([], choices) -> choices - (choices, _) -> choices - - go' (tid, dpor) = (\(ts,c,slp) -> (tid:ts,c,slp)) <$> go tid dpor + go' (tid, dpor) = (\(ts,c,slp) -> (tid:ts,c,slp)) <$> go dpor -- Prefix traces terminating with a to-do decision at this point. here dpor = [([t], c, sleeps dpor) | (t, c) <- M.toList $ dporTodo dpor] @@ -154,16 +132,10 @@ findSchedulePrefix predicate idx dpor0 -- explored. sleeps dpor = dporSleep dpor `M.union` dporTaken dpor - -- The number of pre-emptive context switches - preEmps tid dpor (t:ts) = - let rest = preEmps t (fromJust . M.lookup t $ dporDone dpor) ts - in if tid `S.member` dporRunnable dpor then 1 + rest else rest - preEmps _ _ [] = 0::Int - -- | Add a new trace to the tree, creating a new subtree branching off -- at the point where the \"to-do\" decision was made. incorporateTrace - :: (DepState -> (ThreadId, ThreadAction) -> (ThreadId, ThreadAction) -> Bool) + :: (DepState -> ThreadId -> ThreadAction -> ThreadId -> ThreadAction -> Bool) -- ^ Dependency function -> Bool -- ^ Whether the \"to-do\" point which was used to create this new @@ -176,7 +148,7 @@ incorporateTrace incorporateTrace dependency conservative trace dpor0 = grow initialDepState (initialDPORThread dpor0) trace dpor0 where grow state tid trc@((d, _, a):rest) dpor = let tid' = tidOf tid d - state' = updateDepState state (tid', a) + state' = updateDepState state tid' a in case M.lookup tid' (dporDone dpor) of Just dpor' -> let done = M.insert tid' (grow state' tid' rest dpor') (dporDone dpor) @@ -193,8 +165,8 @@ incorporateTrace dependency conservative trace dpor0 = grow initialDepState (ini -- Construct a new subtree corresponding to a trace suffix. subtree state tid sleep ((_, _, a):rest) = - let state' = updateDepState state (tid, a) - sleep' = M.filterWithKey (\t a' -> not $ dependency state' (tid, a) (t,a')) sleep + let state' = updateDepState state tid a + sleep' = M.filterWithKey (\t a' -> not $ dependency state' tid a t a') sleep in DPOR { dporRunnable = S.fromList $ case rest of ((_, runnable, _):_) -> map fst runnable @@ -225,7 +197,7 @@ incorporateTrace dependency conservative trace dpor0 = grow initialDepState (ini -- runnable, a dependency is imposed between this final action and -- everything else. findBacktrackSteps - :: (DepState -> (ThreadId, ThreadAction) -> (ThreadId, Lookahead) -> Bool) + :: (DepState -> ThreadId -> ThreadAction -> ThreadId -> Lookahead -> Bool) -- ^ Dependency function. -> BacktrackFunc -- ^ Backtracking function. Given a list of backtracking points, and @@ -244,17 +216,16 @@ findBacktrackSteps -> Trace -- ^ The execution trace. -> [BacktrackStep] -findBacktrackSteps _ _ _ bcktrck - | Sq.null bcktrck = const [] -findBacktrackSteps dependency backtrack boundKill bcktrck = go initialDepState S.empty initialThread [] (Sq.viewl bcktrck) where +findBacktrackSteps dependency backtrack boundKill = go initialDepState S.empty initialThread [] . F.toList where -- Walk through the traces one step at a time, building up a list of -- new backtracking points. - go state allThreads tid bs ((e,i): (i', False)) i , bcktState = state' @@ -263,30 +234,41 @@ findBacktrackSteps dependency backtrack boundKill bcktrck = go initialDepState S runnable = S.fromList (M.keys $ bcktRunnable this) allThreads' = allThreads `S.union` runnable killsEarly = null ts && boundKill - in go state' allThreads' tid' bs' (Sq.viewl is) ts + in go state' allThreads' tid' bs' is ts go _ _ _ bs _ _ = bs -- Find the prior actions dependent with this one and add -- backtracking points. doBacktrack killsEarly allThreads enabledThreads bs = let tagged = reverse $ zip [0..] bs - idxs = [ (head is, u) + idxs = [ (head is, False, u) | (u, n) <- enabledThreads , v <- S.toList allThreads , u /= v , let is = idxs' u n v tagged , not $ null is] - idxs' u n v = mapMaybe go' where - go' (i, b) + idxs' u n v = catMaybes . go' True where + {-# INLINE go' #-} + go' final ((i,b):rest) + -- Don't cross subconcurrency boundaries + | isSubC final b = [] -- If this is the final action in the trace and the -- execution was killed due to nothing being within bounds -- (@killsEarly == True@) assume worst-case dependency. - | bcktThreadid b == v && (killsEarly || isDependent b) = Just i - | otherwise = Nothing + | bcktThreadid b == v && (killsEarly || isDependent b) = Just i : go' False rest + | otherwise = go' False rest + go' _ [] = [] - isDependent b = dependency (bcktState b) (bcktThreadid b, snd $ bcktDecision b) (u, n) - in foldl' (\b (i, u) -> backtrack b i u) bs idxs + {-# INLINE isSubC #-} + isSubC final b = case bcktAction b of + Stop -> not final && bcktThreadid b == initialThread + Subconcurrency -> bcktThreadid b == initialThread + _ -> False + + {-# INLINE isDependent #-} + isDependent b = dependency (bcktState b) (bcktThreadid b) (bcktAction b) u n + in backtrack bs idxs -- | Add new backtracking points, if they have not already been -- visited, fit into the bound, and aren't in the sleep set. @@ -302,10 +284,9 @@ incorporateBacktrackSteps bv = go Nothing [] where go priorTid pref (b:bs) bpor = let bpor' = doBacktrack priorTid pref b bpor tid = bcktThreadid b - pref' = pref ++ [bcktDecision b] + pref' = pref ++ [(bcktDecision b, bcktAction b)] child = go (Just tid) pref' bs . fromJust $ M.lookup tid (dporDone bpor) in bpor' { dporDone = M.insert tid child $ dporDone bpor' } - go _ _ [] bpor = bpor doBacktrack priorTid pref b bpor = @@ -343,16 +324,7 @@ data SchedState = SchedState , schedDepState :: DepState -- ^ State used by the dependency function to determine when to -- remove decisions from the sleep set. - } - -instance NFData SchedState where - rnf s = rnf ( schedSleep s - , schedPrefix s - , schedBPoints s - , schedIgnore s - , schedBoundKill s - , schedDepState s - ) + } deriving Show -- | Initial scheduler state for a given prefix initialSchedState :: Map ThreadId ThreadAction @@ -378,52 +350,60 @@ type BoundFunc -- | A backtracking step is a point in the execution where another -- decision needs to be made, in order to explore interesting new -- schedules. A backtracking /function/ takes the steps identified so --- far and a point and a thread to backtrack to, and inserts at least --- that backtracking point. More may be added to compensate for the --- effects of the bounding function. For example, under pre-emption --- bounding a conservative backtracking point is added at the prior --- context switch. +-- far and a list of points and thread at that point to backtrack +-- to. More points be added to compensate for the effects of the +-- bounding function. For example, under pre-emption bounding a +-- conservative backtracking point is added at the prior context +-- switch. The bool is whether the point is conservative. Conservative +-- points are always explored, whereas non-conservative ones might be +-- skipped based on future information. -- -- In general, a backtracking function should identify one or more -- backtracking points, and then use @backtrackAt@ to do the actual -- work. type BacktrackFunc - = [BacktrackStep] -> Int -> ThreadId -> [BacktrackStep] + = [BacktrackStep] -> [(Int, Bool, ThreadId)] -> [BacktrackStep] -- | Add a backtracking point. If the thread isn't runnable, add all -- runnable threads. If the backtracking point is already present, -- don't re-add it UNLESS this would make it conservative. backtrackAt - :: (BacktrackStep -> Bool) + :: (ThreadId -> BacktrackStep -> Bool) -- ^ If this returns @True@, backtrack to all runnable threads, -- rather than just the given thread. - -> Bool - -- ^ Is this backtracking point conservative? Conservative points - -- are always explored, whereas non-conservative ones might be - -- skipped based on future information. -> BacktrackFunc -backtrackAt toAll conservative bs i tid = go bs i where - go bx@(b:rest) 0 +backtrackAt toAll bs0 = backtrackAt' . nubBy ((==) `on` fst') . sortOn fst' where + fst' (x,_,_) = x + + backtrackAt' ((i,c,t):is) = go i bs0 i c t is + backtrackAt' [] = bs0 + + go i0 (b:bs) 0 c tid is -- If the backtracking point is already present, don't re-add it, -- UNLESS this would force it to backtrack (it's conservative) -- where before it might not. - | not (toAll b) && tid `M.member` bcktRunnable b = + | not (toAll tid b) && tid `M.member` bcktRunnable b = let val = M.lookup tid $ bcktBacktracks b - in if isNothing val || (val == Just False && conservative) - then b { bcktBacktracks = backtrackTo b } : rest - else bx - + b' = if isNothing val || (val == Just False && c) + then b { bcktBacktracks = backtrackTo tid c b } + else b + in b' : case is of + ((i',c',t'):is') -> go i' bs (i'-i0-1) c' t' is' + [] -> bs -- Otherwise just backtrack to everything runnable. - | otherwise = b { bcktBacktracks = backtrackAll b } : rest - - go (b:rest) n = b : go rest (n-1) - go [] _ = error "backtrackAt: Ran out of schedule whilst backtracking!" + | otherwise = + let b' = b { bcktBacktracks = backtrackAll c b } + in b' : case is of + ((i',c',t'):is') -> go i' bs (i'-i0-1) c' t' is' + [] -> bs + go i0 (b:bs) i c tid is = b : go i0 bs (i-1) c tid is + go _ [] _ _ _ _ = err "backtrackAt" "ran out of schedule whilst backtracking!" -- Backtrack to a single thread - backtrackTo = M.insert tid conservative . bcktBacktracks + backtrackTo tid c = M.insert tid c . bcktBacktracks -- Backtrack to all runnable threads - backtrackAll = M.map (const conservative) . bcktRunnable + backtrackAll c = M.map (const c) . bcktRunnable -- | DPOR scheduler: takes a list of decisions, and maintains a trace -- including the runnable threads, and the alternative choices allowed @@ -433,17 +413,14 @@ backtrackAt toAll conservative bs i tid = go bs i where -- the prior thread if it's (1) still runnable and (2) hasn't just -- yielded. Furthermore, threads which /will/ yield are ignored in -- preference of those which will not. --- --- This forces full evaluation of the result every step, to avoid any --- possible space leaks. dporSched - :: (DepState -> (ThreadId, ThreadAction) -> (ThreadId, ThreadAction) -> Bool) + :: (DepState -> ThreadId -> ThreadAction -> ThreadId -> ThreadAction -> Bool) -- ^ Dependency function. -> BoundFunc -- ^ Bound function: returns true if that schedule prefix terminated -- with the lookahead decision fits within the bound. -> Scheduler SchedState -dporSched dependency inBound trc prior threads s = force schedule where +dporSched dependency inBound trc prior threads s = schedule where -- Pick a thread to run. schedule = case schedPrefix s of -- If there is a decision available, make it @@ -455,7 +432,7 @@ dporSched dependency inBound trc prior threads s = force schedule where [] -> let choices = restrictToBound initialise checkDep t a = case prior of - Just (tid, act) -> dependency (schedDepState s) (tid, act) (t, a) + Just (tid, act) -> dependency (schedDepState s) tid act t a Nothing -> False ssleep' = M.filterWithKey (\t a -> not $ checkDep t a) $ schedSleep s choices' = filter (`notElem` M.keys ssleep') choices @@ -470,7 +447,7 @@ dporSched dependency inBound trc prior threads s = force schedule where { schedBPoints = schedBPoints s |> (threads, rest) , schedDepState = nextDepState } - nextDepState = let ds = schedDepState s in maybe ds (updateDepState ds) prior + nextDepState = let ds = schedDepState s in maybe ds (uncurry $ updateDepState ds) prior -- Pick a new thread to run, not considering bounds. Choose the -- current thread if available and it hasn't just yielded, otherwise @@ -537,11 +514,7 @@ data DepState = DepState -- the masking state is assumed to be @Unmasked@. This nicely -- provides compatibility with dpor-0.1, where the thread IDs are -- not available. - } - -instance NFData DepState where - -- Cheats: 'MaskingState' has no 'NFData' instance. - rnf ds = rnf (depCRState ds, M.keys (depMaskState ds)) + } deriving (Eq, Show) -- | Initial dependency state. initialDepState :: DepState @@ -549,8 +522,8 @@ initialDepState = DepState M.empty M.empty -- | Update the 'CRef' buffer state with the action that has just -- happened. -updateDepState :: DepState -> (ThreadId, ThreadAction) -> DepState -updateDepState depstate (tid, act) = DepState +updateDepState :: DepState -> ThreadId -> ThreadAction -> DepState +updateDepState depstate tid act = DepState { depCRState = updateCRState act $ depCRState depstate , depMaskState = updateMaskState tid act $ depMaskState depstate } @@ -698,3 +671,17 @@ toDotFiltered check showTid showAct = digraph . go "L" where -- | Internal errors. err :: String -> String -> a err func msg = error (func ++ ": (internal error) " ++ msg) + +-- | A combination of 'partition' and 'concat'. +concatPartition :: (a -> Bool) -> [[a]] -> ([a], [a]) +{-# INLINE concatPartition #-} +-- note: `foldr (flip (foldr select))` is slow, as is `foldl (foldl +-- select))`, and `foldl'` variants. The sweet spot seems to be `foldl +-- (foldr select)` for some reason I don't really understand. +concatPartition p = foldl (foldr select) ([], []) where + -- Lazy pattern matching, got this trick from the 'partition' + -- implementation. This reduces allocation fairly significantly; I + -- do not know why. + select a ~(ts, fs) + | p a = (a:ts, fs) + | otherwise = (ts, a:fs) diff --git a/dejafu/dejafu.cabal b/dejafu/dejafu.cabal index 92093ee..e32550c 100755 --- a/dejafu/dejafu.cabal +++ b/dejafu/dejafu.cabal @@ -96,7 +96,6 @@ library build-depends: base >=4.8 && <5 , concurrency >=1.0 && <1.1 , containers >=0.5 && <0.6 - , deepseq >=1.3 && <1.5 , exceptions >=0.7 && <0.9 , monad-loops >=0.4 && <0.5 , mtl >=2.2 && <2.3