Merge branch 'subconcurrency'

This commit is contained in:
Michael Walker 2017-02-18 20:45:15 +00:00
commit 75d2b6ca73
12 changed files with 555 additions and 613 deletions

View File

@ -11,6 +11,7 @@ import Test.HUnit.DejaFu (testDejafu)
import Control.Concurrent.Classy import Control.Concurrent.Classy
import Control.Monad.STM.Class import Control.Monad.STM.Class
import Test.DejaFu.Conc (Conc, subconcurrency)
#if __GLASGOW_HASKELL__ < 710 #if __GLASGOW_HASKELL__ < 710
import Control.Applicative ((<$>), (<*>)) import Control.Applicative ((<$>), (<*>))
@ -49,6 +50,13 @@ tests =
, testGroup "Daemons" . hUnitTestToTests $ test , testGroup "Daemons" . hUnitTestToTests $ test
[ testDejafu schedDaemon "schedule daemon" $ gives' [0,1] [ testDejafu schedDaemon "schedule daemon" $ gives' [0,1]
] ]
, testGroup "Subconcurrency" . hUnitTestToTests $ test
[ testDejafu scDeadlock1 "deadlock1" $ gives' [Left Deadlock, Right ()]
, testDejafu scDeadlock2 "deadlock2" $ gives' [(Left Deadlock, ()), (Right (), ())]
, testDejafu scSuccess "success" $ gives' [Right ()]
, testDejafu scIllegal "illegal" $ gives [Left IllegalSubconcurrency]
]
] ]
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
@ -207,3 +215,40 @@ schedDaemon = do
x <- newCRef 0 x <- newCRef 0
_ <- fork $ myThreadId >> writeCRef x 1 _ <- fork $ myThreadId >> writeCRef x 1
readCRef x readCRef x
--------------------------------------------------------------------------------
-- Subconcurrency
-- | Subcomputation deadlocks sometimes.
scDeadlock1 :: Monad n => Conc n r (Either Failure ())
scDeadlock1 = do
var <- newEmptyMVar
subconcurrency $ do
void . fork $ putMVar var ()
putMVar var ()
-- | Subcomputation deadlocks sometimes, and action after it still
-- happens.
scDeadlock2 :: Monad n => Conc n r (Either Failure (), ())
scDeadlock2 = do
var <- newEmptyMVar
res <- subconcurrency $ do
void . fork $ putMVar var ()
putMVar var ()
(,) <$> pure res <*> readMVar var
-- | Subcomputation successfully completes.
scSuccess :: Monad n => Conc n r (Either Failure ())
scSuccess = do
var <- newMVar ()
subconcurrency $ do
out <- newEmptyMVar
void . fork $ takeMVar var >>= putMVar out
takeMVar out
-- | Illegal usage
scIllegal :: Monad n => Conc n r ()
scIllegal = do
var <- newEmptyMVar
void . fork $ readMVar var
void . subconcurrency $ pure ()

View File

@ -3,6 +3,7 @@
module Cases.SingleThreaded where module Cases.SingleThreaded where
import Control.Exception (ArithException(..), ArrayException(..)) import Control.Exception (ArithException(..), ArrayException(..))
import Control.Monad (void)
import Test.DejaFu (Failure(..), gives, gives') import Test.DejaFu (Failure(..), gives, gives')
import Test.Framework (Test, testGroup) import Test.Framework (Test, testGroup)
import Test.Framework.Providers.HUnit (hUnitTestToTests) import Test.Framework.Providers.HUnit (hUnitTestToTests)
@ -10,7 +11,7 @@ import Test.HUnit (test)
import Test.HUnit.DejaFu (testDejafu) import Test.HUnit.DejaFu (testDejafu)
import Control.Concurrent.Classy import Control.Concurrent.Classy
import Control.Monad.STM.Class import Test.DejaFu.Conc (Conc, subconcurrency)
import Utils import Utils
@ -58,6 +59,12 @@ tests =
[ testDejafu capsGet "get" $ gives' [True] [ testDejafu capsGet "get" $ gives' [True]
, testDejafu capsSet "set" $ gives' [True] , testDejafu capsSet "set" $ gives' [True]
] ]
, testGroup "Subconcurrency" . hUnitTestToTests $ test
[ testDejafu scDeadlock1 "deadlock1" $ gives' [Left Deadlock]
, testDejafu scDeadlock2 "deadlock2" $ gives' [(Left Deadlock, ())]
, testDejafu scSuccess "success" $ gives' [Right ()]
]
] ]
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
@ -252,3 +259,22 @@ capsSet = do
caps <- getNumCapabilities caps <- getNumCapabilities
setNumCapabilities $ caps + 1 setNumCapabilities $ caps + 1
(== caps + 1) <$> getNumCapabilities (== caps + 1) <$> getNumCapabilities
--------------------------------------------------------------------------------
-- Subconcurrency
-- | Subcomputation deadlocks.
scDeadlock1 :: Monad n => Conc n r (Either Failure ())
scDeadlock1 = subconcurrency (newEmptyMVar >>= readMVar)
-- | Subcomputation deadlocks, and action after it still happens.
scDeadlock2 :: Monad n => Conc n r (Either Failure (), ())
scDeadlock2 = do
var <- newMVar ()
(,) <$> subconcurrency (putMVar var ()) <*> readMVar var
-- | Subcomputation successfully completes.
scSuccess :: Monad n => Conc n r (Either Failure ())
scSuccess = do
var <- newMVar ()
subconcurrency (takeMVar var)

View File

@ -240,7 +240,6 @@ module Test.DejaFu
) where ) where
import Control.Arrow (first) import Control.Arrow (first)
import Control.DeepSeq (NFData(..))
import Control.Monad (when, unless) import Control.Monad (when, unless)
import Control.Monad.Ref (MonadRef) import Control.Monad.Ref (MonadRef)
import Control.Monad.ST (runST) import Control.Monad.ST (runST)
@ -400,9 +399,6 @@ defaultFail failures = Result False 0 failures ""
defaultPass :: Result a defaultPass :: Result a
defaultPass = Result True 0 [] "" defaultPass = Result True 0 [] ""
instance NFData a => NFData (Result a) where
rnf r = rnf (_pass r, _casesChecked r, _failures r, _failureMsg r)
instance Functor Result where instance Functor Result where
fmap f r = r { _failures = map (first $ fmap f) $ _failures r } fmap f r = r { _failures = map (first $ fmap f) $ _failures r }

View File

@ -60,7 +60,6 @@ module Test.DejaFu.Common
, MemType(..) , MemType(..)
) where ) where
import Control.DeepSeq (NFData(..))
import Control.Exception (MaskingState(..)) import Control.Exception (MaskingState(..))
import Data.Dynamic (Dynamic) import Data.Dynamic (Dynamic)
import Data.List (sort, nub, intercalate) import Data.List (sort, nub, intercalate)
@ -83,9 +82,6 @@ instance Show ThreadId where
show (ThreadId (Just n) _) = n show (ThreadId (Just n) _) = n
show (ThreadId Nothing i) = show i show (ThreadId Nothing i) = show i
instance NFData ThreadId where
rnf (ThreadId n i) = rnf (n, i)
-- | Every @CRef@ has a unique identifier. -- | Every @CRef@ has a unique identifier.
data CRefId = CRefId (Maybe String) Int data CRefId = CRefId (Maybe String) Int
deriving Eq deriving Eq
@ -97,9 +93,6 @@ instance Show CRefId where
show (CRefId (Just n) _) = n show (CRefId (Just n) _) = n
show (CRefId Nothing i) = show i show (CRefId Nothing i) = show i
instance NFData CRefId where
rnf (CRefId n i) = rnf (n, i)
-- | Every @MVar@ has a unique identifier. -- | Every @MVar@ has a unique identifier.
data MVarId = MVarId (Maybe String) Int data MVarId = MVarId (Maybe String) Int
deriving Eq deriving Eq
@ -111,9 +104,6 @@ instance Show MVarId where
show (MVarId (Just n) _) = n show (MVarId (Just n) _) = n
show (MVarId Nothing i) = show i show (MVarId Nothing i) = show i
instance NFData MVarId where
rnf (MVarId n i) = rnf (n, i)
-- | Every @TVar@ has a unique identifier. -- | Every @TVar@ has a unique identifier.
data TVarId = TVarId (Maybe String) Int data TVarId = TVarId (Maybe String) Int
deriving Eq deriving Eq
@ -125,9 +115,6 @@ instance Show TVarId where
show (TVarId (Just n) _) = n show (TVarId (Just n) _) = n
show (TVarId Nothing i) = show i show (TVarId Nothing i) = show i
instance NFData TVarId where
rnf (TVarId n i) = rnf (n, i)
-- | The ID of the initial thread. -- | The ID of the initial thread.
initialThread :: ThreadId initialThread :: ThreadId
initialThread = ThreadId (Just "main") 0 initialThread = ThreadId (Just "main") 0
@ -272,38 +259,10 @@ data ThreadAction =
-- ^ A '_concMessage' annotation was processed. -- ^ A '_concMessage' annotation was processed.
| Stop | Stop
-- ^ Cease execution and terminate. -- ^ Cease execution and terminate.
| Subconcurrency
-- ^ Start executing an action with @subconcurrency@.
deriving Show deriving Show
instance NFData ThreadAction where
rnf (Fork t) = rnf t
rnf (GetNumCapabilities i) = rnf i
rnf (SetNumCapabilities i) = rnf i
rnf (NewVar c) = rnf c
rnf (PutVar c ts) = rnf (c, ts)
rnf (BlockedPutVar c) = rnf c
rnf (TryPutVar c b ts) = rnf (c, b, ts)
rnf (ReadVar c) = rnf c
rnf (BlockedReadVar c) = rnf c
rnf (TakeVar c ts) = rnf (c, ts)
rnf (BlockedTakeVar c) = rnf c
rnf (TryTakeVar c b ts) = rnf (c, b, ts)
rnf (NewRef c) = rnf c
rnf (ReadRef c) = rnf c
rnf (ReadRefCas c) = rnf c
rnf (ModRef c) = rnf c
rnf (ModRefCas c) = rnf c
rnf (WriteRef c) = rnf c
rnf (CasRef c b) = rnf (c, b)
rnf (CommitRef t c) = rnf (t, c)
rnf (STM s ts) = rnf (s, ts)
rnf (BlockedSTM s) = rnf s
rnf (ThrowTo t) = rnf t
rnf (BlockedThrowTo t) = rnf t
rnf (SetMasking b m) = b `seq` m `seq` ()
rnf (ResetMasking b m) = b `seq` m `seq` ()
rnf (Message m) = m `seq` ()
rnf a = a `seq` ()
-- | Check if a @ThreadAction@ immediately blocks. -- | Check if a @ThreadAction@ immediately blocks.
isBlock :: ThreadAction -> Bool isBlock :: ThreadAction -> Bool
isBlock (BlockedThrowTo _) = True isBlock (BlockedThrowTo _) = True
@ -401,28 +360,10 @@ data Lookahead =
-- ^ Will process a _concMessage' annotation. -- ^ Will process a _concMessage' annotation.
| WillStop | WillStop
-- ^ Will cease execution and terminate. -- ^ Will cease execution and terminate.
| WillSubconcurrency
-- ^ Will execute an action with @subconcurrency@.
deriving Show deriving Show
instance NFData Lookahead where
rnf (WillSetNumCapabilities i) = rnf i
rnf (WillPutVar c) = rnf c
rnf (WillTryPutVar c) = rnf c
rnf (WillReadVar c) = rnf c
rnf (WillTakeVar c) = rnf c
rnf (WillTryTakeVar c) = rnf c
rnf (WillReadRef c) = rnf c
rnf (WillReadRefCas c) = rnf c
rnf (WillModRef c) = rnf c
rnf (WillModRefCas c) = rnf c
rnf (WillWriteRef c) = rnf c
rnf (WillCasRef c) = rnf c
rnf (WillCommitRef t c) = rnf (t, c)
rnf (WillThrowTo t) = rnf t
rnf (WillSetMasking b m) = b `seq` m `seq` ()
rnf (WillResetMasking b m) = b `seq` m `seq` ()
rnf (WillMessage m) = m `seq` ()
rnf l = l `seq` ()
-- | Convert a 'ThreadAction' into a 'Lookahead': \"rewind\" what has -- | Convert a 'ThreadAction' into a 'Lookahead': \"rewind\" what has
-- happened. 'Killed' has no 'Lookahead' counterpart. -- happened. 'Killed' has no 'Lookahead' counterpart.
rewind :: ThreadAction -> Maybe Lookahead rewind :: ThreadAction -> Maybe Lookahead
@ -462,6 +403,7 @@ rewind LiftIO = Just WillLiftIO
rewind Return = Just WillReturn rewind Return = Just WillReturn
rewind (Message m) = Just (WillMessage m) rewind (Message m) = Just (WillMessage m)
rewind Stop = Just WillStop rewind Stop = Just WillStop
rewind Subconcurrency = Just WillSubconcurrency
-- | Check if an operation could enable another thread. -- | Check if an operation could enable another thread.
willRelease :: Lookahead -> Bool willRelease :: Lookahead -> Bool
@ -508,17 +450,6 @@ data ActionType =
-- communication. -- communication.
deriving (Eq, Show) deriving (Eq, Show)
instance NFData ActionType where
rnf (UnsynchronisedRead r) = rnf r
rnf (UnsynchronisedWrite r) = rnf r
rnf (PartiallySynchronisedCommit r) = rnf r
rnf (PartiallySynchronisedWrite r) = rnf r
rnf (PartiallySynchronisedModify r) = rnf r
rnf (SynchronisedModify r) = rnf r
rnf (SynchronisedRead c) = rnf c
rnf (SynchronisedWrite c) = rnf c
rnf a = a `seq` ()
-- | Check if an action imposes a write barrier. -- | Check if an action imposes a write barrier.
isBarrier :: ActionType -> Bool isBarrier :: ActionType -> Bool
isBarrier (SynchronisedModify _) = True isBarrier (SynchronisedModify _) = True
@ -611,13 +542,6 @@ data TAction =
-- ^ Terminate successfully and commit effects. -- ^ Terminate successfully and commit effects.
deriving (Eq, Show) deriving (Eq, Show)
instance NFData TAction where
rnf (TRead v) = rnf v
rnf (TWrite v) = rnf v
rnf (TCatch s m) = rnf (s, m)
rnf (TOrElse s m) = rnf (s, m)
rnf a = a `seq` ()
------------------------------------------------------------------------------- -------------------------------------------------------------------------------
-- Traces -- Traces
@ -640,11 +564,6 @@ data Decision =
-- ^ Pre-empt the running thread, and switch to another. -- ^ Pre-empt the running thread, and switch to another.
deriving (Eq, Show) deriving (Eq, Show)
instance NFData Decision where
rnf (Start tid) = rnf tid
rnf (SwitchTo tid) = rnf tid
rnf d = d `seq` ()
-- | Pretty-print a trace, including a key of the thread IDs (not -- | Pretty-print a trace, including a key of the thread IDs (not
-- including thread 0). Each line of the key is indented by two -- including thread 0). Each line of the key is indented by two
-- spaces. -- spaces.
@ -675,15 +594,15 @@ showTrace trc = intercalate "\n" $ concatMap go trc : strkey where
preEmpCount :: [(Decision, ThreadAction)] preEmpCount :: [(Decision, ThreadAction)]
-> (Decision, Lookahead) -> (Decision, Lookahead)
-> Int -> Int
preEmpCount ts (d, _) = go initialThread Nothing ts where preEmpCount (x:xs) (d, _) = go initialThread x xs where
go _ (Just Yield) ((SwitchTo t, a):rest) = go t (Just a) rest go _ (_, Yield) (r@(SwitchTo t, _):rest) = go t r rest
go tid prior ((SwitchTo t, a):rest) go tid prior (r@(SwitchTo t, _):rest)
| isCommitThread t = go tid prior (skip rest) | isCommitThread t = go tid prior (skip rest)
| otherwise = 1 + go t (Just a) rest | otherwise = 1 + go t r rest
go _ _ ((Start t, a):rest) = go t (Just a) rest go _ _ (r@(Start t, _):rest) = go t r rest
go tid _ ((Continue, a):rest) = go tid (Just a) rest go tid _ (r@(Continue, _):rest) = go tid r rest
go _ prior [] = case (prior, d) of go _ prior [] = case (prior, d) of
(Just Yield, SwitchTo _) -> 0 ((_, Yield), SwitchTo _) -> 0
(_, SwitchTo _) -> 1 (_, SwitchTo _) -> 1
_ -> 0 _ -> 0
@ -694,6 +613,7 @@ preEmpCount ts (d, _) = go initialThread Nothing ts where
skip = dropWhile (not . isContextSwitch . fst) skip = dropWhile (not . isContextSwitch . fst)
isContextSwitch Continue = False isContextSwitch Continue = False
isContextSwitch _ = True isContextSwitch _ = True
preEmpCount [] _ = 0
------------------------------------------------------------------------------- -------------------------------------------------------------------------------
-- Failures -- Failures
@ -716,11 +636,11 @@ data Failure =
-- ^ The computation became blocked indefinitely on @TVar@s. -- ^ The computation became blocked indefinitely on @TVar@s.
| UncaughtException | UncaughtException
-- ^ An uncaught exception bubbled to the top of the computation. -- ^ An uncaught exception bubbled to the top of the computation.
| IllegalSubconcurrency
-- ^ Calls to @subconcurrency@ were nested, or attempted when
-- multiple threads existed.
deriving (Eq, Show, Read, Ord, Enum, Bounded) deriving (Eq, Show, Read, Ord, Enum, Bounded)
instance NFData Failure where
rnf f = f `seq` ()
-- | Pretty-print a failure -- | Pretty-print a failure
showFail :: Failure -> String showFail :: Failure -> String
showFail Abort = "[abort]" showFail Abort = "[abort]"
@ -728,6 +648,7 @@ showFail Deadlock = "[deadlock]"
showFail STMDeadlock = "[stm-deadlock]" showFail STMDeadlock = "[stm-deadlock]"
showFail InternalError = "[internal-error]" showFail InternalError = "[internal-error]"
showFail UncaughtException = "[exception]" showFail UncaughtException = "[exception]"
showFail IllegalSubconcurrency = "[illegal-subconcurrency]"
------------------------------------------------------------------------------- -------------------------------------------------------------------------------
-- Memory Models -- Memory Models
@ -751,9 +672,6 @@ data MemType =
-- created. -- created.
deriving (Eq, Show, Read, Ord, Enum, Bounded) deriving (Eq, Show, Read, Ord, Enum, Bounded)
instance NFData MemType where
rnf m = m `seq` ()
------------------------------------------------------------------------------- -------------------------------------------------------------------------------
-- Utilities -- Utilities

View File

@ -28,6 +28,7 @@ module Test.DejaFu.Conc
, Failure(..) , Failure(..)
, MemType(..) , MemType(..)
, runConcurrent , runConcurrent
, subconcurrency
-- * Execution traces -- * Execution traces
, Trace , Trace
@ -49,12 +50,11 @@ import Control.Exception (MaskingState(..))
import qualified Control.Monad.Base as Ba import qualified Control.Monad.Base as Ba
import qualified Control.Monad.Catch as Ca import qualified Control.Monad.Catch as Ca
import qualified Control.Monad.IO.Class as IO import qualified Control.Monad.IO.Class as IO
import Control.Monad.Ref (MonadRef, newRef, readRef, writeRef) import Control.Monad.Ref (MonadRef,)
import Control.Monad.ST (ST) import Control.Monad.ST (ST)
import Data.Dynamic (toDyn) import Data.Dynamic (toDyn)
import qualified Data.Foldable as F
import Data.IORef (IORef) import Data.IORef (IORef)
import qualified Data.Map.Strict as M
import Data.Maybe (fromJust)
import Data.STRef (STRef) import Data.STRef (STRef)
import Test.DejaFu.Schedule import Test.DejaFu.Schedule
@ -62,13 +62,12 @@ import qualified Control.Monad.Conc.Class as C
import Test.DejaFu.Common import Test.DejaFu.Common
import Test.DejaFu.Conc.Internal import Test.DejaFu.Conc.Internal
import Test.DejaFu.Conc.Internal.Common import Test.DejaFu.Conc.Internal.Common
import Test.DejaFu.Conc.Internal.Threading
import Test.DejaFu.STM import Test.DejaFu.STM
{-# ANN module ("HLint: ignore Avoid lambda" :: String) #-} {-# ANN module ("HLint: ignore Avoid lambda" :: String) #-}
{-# ANN module ("HLint: ignore Use const" :: String) #-} {-# ANN module ("HLint: ignore Use const" :: String) #-}
newtype Conc n r a = C { unC :: M n r (STMLike n r) a } deriving (Functor, Applicative, Monad) newtype Conc n r a = C { unC :: M n r a } deriving (Functor, Applicative, Monad)
-- | A 'MonadConc' implementation using @ST@, this should be preferred -- | A 'MonadConc' implementation using @ST@, this should be preferred
-- if you do not need 'liftIO'. -- if you do not need 'liftIO'.
@ -77,10 +76,10 @@ type ConcST t = Conc (ST t) (STRef t)
-- | A 'MonadConc' implementation using @IO@. -- | A 'MonadConc' implementation using @IO@.
type ConcIO = Conc IO IORef type ConcIO = Conc IO IORef
toConc :: ((a -> Action n r (STMLike n r)) -> Action n r (STMLike n r)) -> Conc n r a toConc :: ((a -> Action n r) -> Action n r) -> Conc n r a
toConc = C . cont toConc = C . cont
wrap :: (M n r (STMLike n r) a -> M n r (STMLike n r) a) -> Conc n r a -> Conc n r a wrap :: (M n r a -> M n r a) -> Conc n r a -> Conc n r a
wrap f = C . f . unC wrap f = C . f . unC
instance IO.MonadIO ConcIO where instance IO.MonadIO ConcIO where
@ -181,20 +180,15 @@ runConcurrent :: MonadRef r n
-> s -> s
-> Conc n r a -> Conc n r a
-> n (Either Failure a, s, Trace) -> n (Either Failure a, s, Trace)
runConcurrent sched memtype s (C conc) = do runConcurrent sched memtype s ma = do
ref <- newRef Nothing (res, s', trace) <- runConcurrency sched memtype s (unC ma)
pure (res, s', F.toList trace)
let c = runCont conc (AStop . writeRef ref . Just . Right) -- | Run a concurrent computation and return its result.
let threads = launch' Unmasked initialThread (const c) M.empty --
-- This can only be called in the main thread, when no other threads
(s', trace) <- runThreads runTransaction -- exist. Calls to 'subconcurrency' cannot be nested. Violating either
sched -- of these conditions will result in the computation failing with
memtype -- @IllegalSubconcurrency@.
s subconcurrency :: Conc n r a -> Conc n r (Either Failure a)
threads subconcurrency ma = toConc (ASub (unC ma))
initialIdSource
ref
out <- readRef ref
pure (fromJust out, s', reverse trace)

View File

@ -16,19 +16,23 @@
module Test.DejaFu.Conc.Internal where module Test.DejaFu.Conc.Internal where
import Control.Exception (MaskingState(..), toException) import Control.Exception (MaskingState(..), toException)
import Control.Monad.Ref (MonadRef, newRef, writeRef) import Control.Monad.Ref (MonadRef, newRef, readRef, writeRef)
import qualified Data.Foldable as F
import Data.Functor (void) import Data.Functor (void)
import Data.List (sort) import Data.List (sort)
import Data.List.NonEmpty (NonEmpty(..), fromList) import Data.List.NonEmpty (NonEmpty(..), fromList)
import qualified Data.Map.Strict as M import qualified Data.Map.Strict as M
import Data.Maybe (fromJust, isJust, isNothing, listToMaybe) import Data.Maybe (fromJust, isJust, isNothing)
import Data.Monoid ((<>))
import Data.Sequence (Seq, (<|))
import qualified Data.Sequence as Seq
import Test.DejaFu.Common import Test.DejaFu.Common
import Test.DejaFu.Conc.Internal.Common import Test.DejaFu.Conc.Internal.Common
import Test.DejaFu.Conc.Internal.Memory import Test.DejaFu.Conc.Internal.Memory
import Test.DejaFu.Conc.Internal.Threading import Test.DejaFu.Conc.Internal.Threading
import Test.DejaFu.Schedule import Test.DejaFu.Schedule
import Test.DejaFu.STM (Result(..)) import Test.DejaFu.STM (Result(..), runTransaction)
{-# ANN module ("HLint: ignore Use record patterns" :: String) #-} {-# ANN module ("HLint: ignore Use record patterns" :: String) #-}
{-# ANN module ("HLint: ignore Use const" :: String) #-} {-# ANN module ("HLint: ignore Use const" :: String) #-}
@ -36,40 +40,69 @@ import Test.DejaFu.STM (Result(..))
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- * Execution -- * Execution
-- | 'Trace' but as a sequence.
type SeqTrace
= Seq (Decision, [(ThreadId, NonEmpty Lookahead)], ThreadAction)
-- | Run a concurrent computation with a given 'Scheduler' and initial
-- state, returning a failure reason on error. Also returned is the
-- final state of the scheduler, and an execution trace.
runConcurrency :: MonadRef r n
=> Scheduler g
-> MemType
-> g
-> M n r a
-> n (Either Failure a, g, SeqTrace)
runConcurrency sched memtype g ma = do
ref <- newRef Nothing
let c = runCont ma (AStop . writeRef ref . Just . Right)
let threads = launch' Unmasked initialThread (const c) M.empty
let ctx = Context { cSchedState = g, cIdSource = initialIdSource, cThreads = threads, cWriteBuf = emptyBuffer, cCaps = 2 }
(finalCtx, trace) <- runThreads sched memtype ref ctx
out <- readRef ref
pure (fromJust out, cSchedState finalCtx, trace)
-- | The context a collection of threads are running in.
data Context n r g = Context
{ cSchedState :: g
, cIdSource :: IdSource
, cThreads :: Threads n r
, cWriteBuf :: WriteBuffer r
, cCaps :: Int
}
-- | Run a collection of threads, until there are no threads left. -- | Run a collection of threads, until there are no threads left.
-- runThreads :: MonadRef r n
-- Note: this returns the trace in reverse order, because it's more => Scheduler g -> MemType -> r (Maybe (Either Failure a)) -> Context n r g -> n (Context n r g, SeqTrace)
-- efficient to prepend to a list than append. As this function isn't runThreads sched memtype ref = go Seq.empty [] Nothing where
-- exposed to users of the library, this is just an internal gotcha to -- sofar is the 'SeqTrace', sofarSched is the @[(Decision,
-- watch out for. -- ThreadAction)]@ trace the scheduler needs.
runThreads :: MonadRef r n => (forall x. s x -> IdSource -> n (Result x, IdSource, TTrace)) go sofar sofarSched prior ctx
-> Scheduler g -> MemType -> g -> Threads n r s -> IdSource -> r (Maybe (Either Failure a)) -> n (g, Trace) | isTerminated = stop ctx
runThreads runstm sched memtype origg origthreads idsrc ref = go idsrc [] Nothing origg origthreads emptyBuffer 2 where | isDeadlocked = die Deadlock ctx
go idSource sofar prior g threads wb caps | isSTMLocked = die STMDeadlock ctx
| isTerminated = stop g | isAborted = die Abort $ ctx { cSchedState = g' }
| isDeadlocked = die g Deadlock | isNonexistant = die InternalError $ ctx { cSchedState = g' }
| isSTMLocked = die g STMDeadlock | isBlocked = die InternalError $ ctx { cSchedState = g' }
| isAborted = die g' Abort
| isNonexistant = die g' InternalError
| isBlocked = die g' InternalError
| otherwise = do | otherwise = do
stepped <- stepThread runstm memtype (_continuation $ fromJust thread) idSource chosen threads wb caps stepped <- stepThread sched memtype chosen (_continuation $ fromJust thread) $ ctx { cSchedState = g' }
case stepped of case stepped of
Right (threads', idSource', act, wb', caps') -> loop threads' idSource' act wb' caps' Right (ctx', actOrTrc) -> loop actOrTrc ctx'
Left UncaughtException Left UncaughtException
| chosen == initialThread -> die g' UncaughtException | chosen == initialThread -> die UncaughtException $ ctx { cSchedState = g' }
| otherwise -> loop (kill chosen threads) idSource Killed wb caps | otherwise -> loop (Right Killed) $ ctx { cThreads = kill chosen threadsc, cSchedState = g' }
Left failure -> die failure $ ctx { cSchedState = g' }
Left failure -> die g' failure
where where
(choice, g') = sched (map (\(d,_,a) -> (d,a)) $ reverse sofar) ((\p (_,_,a) -> (p,a)) <$> prior <*> listToMaybe sofar) (fromList $ map (\(t,l:|_) -> (t,l)) runnable') g (choice, g') = sched sofarSched prior (fromList $ map (\(t,l:|_) -> (t,l)) runnable') (cSchedState ctx)
chosen = fromJust choice chosen = fromJust choice
runnable' = [(t, nextActions t) | t <- sort $ M.keys runnable] runnable' = [(t, nextActions t) | t <- sort $ M.keys runnable]
runnable = M.filter (isNothing . _blocking) threadsc runnable = M.filter (isNothing . _blocking) threadsc
thread = M.lookup chosen threadsc thread = M.lookup chosen threadsc
threadsc = addCommitThreads wb threads threadsc = addCommitThreads (cWriteBuf ctx) threads
threads = cThreads ctx
isAborted = isNothing choice isAborted = isNothing choice
isBlocked = isJust . _blocking $ fromJust thread isBlocked = isJust . _blocking $ fromJust thread
isNonexistant = isNothing thread isNonexistant = isNothing thread
@ -87,299 +120,262 @@ runThreads runstm sched memtype origg origthreads idsrc ref = go idsrc [] Nothin
_ -> thrd _ -> thrd
decision decision
| Just chosen == prior = Continue | Just chosen == (fst <$> prior) = Continue
| prior `notElem` map (Just . fst) runnable' = Start chosen | (fst <$> prior) `notElem` map (Just . fst) runnable' = Start chosen
| otherwise = SwitchTo chosen | otherwise = SwitchTo chosen
nextActions t = lookahead . _continuation . fromJust $ M.lookup t threadsc nextActions t = lookahead . _continuation . fromJust $ M.lookup t threadsc
stop outg = pure (outg, sofar) stop finalCtx = pure (finalCtx, sofar)
die outg reason = writeRef ref (Just $ Left reason) >> stop outg die reason finalCtx = writeRef ref (Just $ Left reason) >> stop finalCtx
loop threads' idSource' act wb' = loop trcOrAct ctx' =
let sofar' = ((decision, runnable', act) : sofar) let (act, trc) = case trcOrAct of
threads'' = if (interruptible <$> M.lookup chosen threads') /= Just False then unblockWaitingOn chosen threads' else threads' Left (a, as) -> (a, (decision, runnable', a) <| as)
in go idSource' sofar' (Just chosen) g' (delCommitThreads threads'') wb' Right a -> (a, Seq.singleton (decision, runnable', a))
threads' = if (interruptible <$> M.lookup chosen (cThreads ctx')) /= Just False
then unblockWaitingOn chosen (cThreads ctx')
else cThreads ctx'
sofar' = sofar <> trc
sofarSched' = sofarSched <> map (\(d,_,a) -> (d,a)) (F.toList trc)
prior' = Just (chosen, act)
in go sofar' sofarSched' prior' $ ctx' { cThreads = delCommitThreads threads' }
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- * Single-step execution -- * Single-step execution
-- | Run a single thread one step, by dispatching on the type of -- | Run a single thread one step, by dispatching on the type of
-- 'Action'. -- 'Action'.
stepThread :: forall n r s. MonadRef r n stepThread :: forall n r g. MonadRef r n
=> (forall x. s x -> IdSource -> n (Result x, IdSource, TTrace)) => Scheduler g
-- ^ Run a 'MonadSTM' transaction atomically. -- ^ The scheduler.
-> MemType -> MemType
-- ^ The memory model -- ^ The memory model to use.
-> Action n r s
-- ^ Action to step
-> IdSource
-- ^ Source of fresh IDs
-> ThreadId -> ThreadId
-- ^ ID of the current thread -- ^ ID of the current thread
-> Threads n r s -> Action n r
-- ^ Current state of threads -- ^ Action to step
-> WriteBuffer r -> Context n r g
-- ^ @CRef@ write buffer -- ^ The execution context.
-> Int -> n (Either Failure (Context n r g, Either (ThreadAction, SeqTrace) ThreadAction))
-- ^ The number of capabilities stepThread sched memtype tid action ctx = case action of
-> n (Either Failure (Threads n r s, IdSource, ThreadAction, WriteBuffer r, Int)) -- start a new thread, assigning it the next 'ThreadId'
stepThread runstm memtype action idSource tid threads wb caps = case action of AFork n a b -> pure . Right $
AFork n a b -> stepFork n a b let threads' = launch tid newtid a (cThreads ctx)
AMyTId c -> stepMyTId c (idSource', newtid) = nextTId n (cIdSource ctx)
AGetNumCapabilities c -> stepGetNumCapabilities c in (ctx { cThreads = goto (b newtid) tid threads', cIdSource = idSource' }, Right (Fork newtid))
ASetNumCapabilities i c -> stepSetNumCapabilities i c
AYield c -> stepYield c
ANewVar n c -> stepNewVar n c
APutVar var a c -> stepPutVar var a c
ATryPutVar var a c -> stepTryPutVar var a c
AReadVar var c -> stepReadVar var c
ATakeVar var c -> stepTakeVar var c
ATryTakeVar var c -> stepTryTakeVar var c
ANewRef n a c -> stepNewRef n a c
AReadRef ref c -> stepReadRef ref c
AReadRefCas ref c -> stepReadRefCas ref c
AModRef ref f c -> stepModRef ref f c
AModRefCas ref f c -> stepModRefCas ref f c
AWriteRef ref a c -> stepWriteRef ref a c
ACasRef ref tick a c -> stepCasRef ref tick a c
ACommit t c -> stepCommit t c
AAtom stm c -> stepAtom stm c
ALift na -> stepLift na
AThrow e -> stepThrow e
AThrowTo t e c -> stepThrowTo t e c
ACatching h ma c -> stepCatching h ma c
APopCatching a -> stepPopCatching a
AMasking m ma c -> stepMasking m ma c
AResetMask b1 b2 m c -> stepResetMask b1 b2 m c
AReturn c -> stepReturn c
AMessage m c -> stepMessage m c
AStop na -> stepStop na
where -- get the 'ThreadId' of the current thread
-- | Start a new thread, assigning it the next 'ThreadId' AMyTId c -> simple (goto (c tid) tid (cThreads ctx)) MyThreadId
--
-- Explicit type signature needed for GHC 8. Looks like the
-- impredicative polymorphism checks got stronger.
stepFork :: String
-> ((forall b. M n r s b -> M n r s b) -> Action n r s)
-> (ThreadId -> Action n r s)
-> n (Either Failure (Threads n r s, IdSource, ThreadAction, WriteBuffer r, Int))
stepFork n a b = return $ Right (goto (b newtid) tid threads', idSource', Fork newtid, wb, caps) where
threads' = launch tid newtid a threads
(idSource', newtid) = nextTId n idSource
-- | Get the 'ThreadId' of the current thread -- get the number of capabilities
stepMyTId c = simple (goto (c tid) tid threads) MyThreadId AGetNumCapabilities c -> simple (goto (c (cCaps ctx)) tid (cThreads ctx)) $ GetNumCapabilities (cCaps ctx)
-- | Get the number of capabilities -- set the number of capabilities
stepGetNumCapabilities c = simple (goto (c caps) tid threads) $ GetNumCapabilities caps ASetNumCapabilities i c -> pure . Right $
(ctx { cThreads = goto c tid (cThreads ctx), cCaps = i }, Right (SetNumCapabilities i))
-- | Set the number of capabilities -- yield the current thread
stepSetNumCapabilities i c = return $ Right (goto c tid threads, idSource, SetNumCapabilities i, wb, i) AYield c -> simple (goto c tid (cThreads ctx)) Yield
-- | Yield the current thread -- create a new @MVar@, using the next 'MVarId'.
stepYield c = simple (goto c tid threads) Yield ANewVar n c -> do
let (idSource', newmvid) = nextMVId n (cIdSource ctx)
ref <- newRef Nothing
let mvar = MVar newmvid ref
pure $ Right (ctx { cThreads = goto (c mvar) tid (cThreads ctx), cIdSource = idSource' }, Right (NewVar newmvid))
-- | Put a value into a @MVar@, blocking the thread until it's -- put a value into a @MVar@, blocking the thread until it's empty.
-- empty. APutVar cvar@(MVar cvid _) a c -> synchronised $ do
stepPutVar cvar@(MVar cvid _) a c = synchronised $ do (success, threads', woken) <- putIntoMVar cvar a c tid (cThreads ctx)
(success, threads', woken) <- putIntoMVar cvar a c tid threads
simple threads' $ if success then PutVar cvid woken else BlockedPutVar cvid simple threads' $ if success then PutVar cvid woken else BlockedPutVar cvid
-- | Try to put a value into a @MVar@, without blocking. -- try to put a value into a @MVar@, without blocking.
stepTryPutVar cvar@(MVar cvid _) a c = synchronised $ do ATryPutVar cvar@(MVar cvid _) a c -> synchronised $ do
(success, threads', woken) <- tryPutIntoMVar cvar a c tid threads (success, threads', woken) <- tryPutIntoMVar cvar a c tid (cThreads ctx)
simple threads' $ TryPutVar cvid success woken simple threads' $ TryPutVar cvid success woken
-- | Get the value from a @MVar@, without emptying, blocking the -- get the value from a @MVar@, without emptying, blocking the
-- thread until it's full. -- thread until it's full.
stepReadVar cvar@(MVar cvid _) c = synchronised $ do AReadVar cvar@(MVar cvid _) c -> synchronised $ do
(success, threads', _) <- readFromMVar cvar c tid threads (success, threads', _) <- readFromMVar cvar c tid (cThreads ctx)
simple threads' $ if success then ReadVar cvid else BlockedReadVar cvid simple threads' $ if success then ReadVar cvid else BlockedReadVar cvid
-- | Take the value from a @MVar@, blocking the thread until it's -- take the value from a @MVar@, blocking the thread until it's
-- full. -- full.
stepTakeVar cvar@(MVar cvid _) c = synchronised $ do ATakeVar cvar@(MVar cvid _) c -> synchronised $ do
(success, threads', woken) <- takeFromMVar cvar c tid threads (success, threads', woken) <- takeFromMVar cvar c tid (cThreads ctx)
simple threads' $ if success then TakeVar cvid woken else BlockedTakeVar cvid simple threads' $ if success then TakeVar cvid woken else BlockedTakeVar cvid
-- | Try to take the value from a @MVar@, without blocking. -- try to take the value from a @MVar@, without blocking.
stepTryTakeVar cvar@(MVar cvid _) c = synchronised $ do ATryTakeVar cvar@(MVar cvid _) c -> synchronised $ do
(success, threads', woken) <- tryTakeFromMVar cvar c tid threads (success, threads', woken) <- tryTakeFromMVar cvar c tid (cThreads ctx)
simple threads' $ TryTakeVar cvid success woken simple threads' $ TryTakeVar cvid success woken
-- | Read from a @CRef@. -- create a new @CRef@, using the next 'CRefId'.
stepReadRef cref@(CRef crid _) c = do ANewRef n a c -> do
let (idSource', newcrid) = nextCRId n (cIdSource ctx)
ref <- newRef (M.empty, 0, a)
let cref = CRef newcrid ref
pure $ Right (ctx { cThreads = goto (c cref) tid (cThreads ctx), cIdSource = idSource' }, Right (NewRef newcrid))
-- read from a @CRef@.
AReadRef cref@(CRef crid _) c -> do
val <- readCRef cref tid val <- readCRef cref tid
simple (goto (c val) tid threads) $ ReadRef crid simple (goto (c val) tid (cThreads ctx)) $ ReadRef crid
-- | Read from a @CRef@ for future compare-and-swap operations. -- read from a @CRef@ for future compare-and-swap operations.
stepReadRefCas cref@(CRef crid _) c = do AReadRefCas cref@(CRef crid _) c -> do
tick <- readForTicket cref tid tick <- readForTicket cref tid
simple (goto (c tick) tid threads) $ ReadRefCas crid simple (goto (c tick) tid (cThreads ctx)) $ ReadRefCas crid
-- | Modify a @CRef@. -- modify a @CRef@.
stepModRef cref@(CRef crid _) f c = synchronised $ do AModRef cref@(CRef crid _) f c -> synchronised $ do
(new, val) <- f <$> readCRef cref tid (new, val) <- f <$> readCRef cref tid
writeImmediate cref new writeImmediate cref new
simple (goto (c val) tid threads) $ ModRef crid simple (goto (c val) tid (cThreads ctx)) $ ModRef crid
-- | Modify a @CRef@ using a compare-and-swap. -- modify a @CRef@ using a compare-and-swap.
stepModRefCas cref@(CRef crid _) f c = synchronised $ do AModRefCas cref@(CRef crid _) f c -> synchronised $ do
tick@(Ticket _ _ old) <- readForTicket cref tid tick@(Ticket _ _ old) <- readForTicket cref tid
let (new, val) = f old let (new, val) = f old
void $ casCRef cref tid tick new void $ casCRef cref tid tick new
simple (goto (c val) tid threads) $ ModRefCas crid simple (goto (c val) tid (cThreads ctx)) $ ModRefCas crid
-- | Write to a @CRef@ without synchronising -- write to a @CRef@ without synchronising.
stepWriteRef cref@(CRef crid _) a c = case memtype of AWriteRef cref@(CRef crid _) a c -> case memtype of
-- Write immediately. -- write immediately.
SequentialConsistency -> do SequentialConsistency -> do
writeImmediate cref a writeImmediate cref a
simple (goto c tid threads) $ WriteRef crid simple (goto c tid (cThreads ctx)) $ WriteRef crid
-- add to buffer using thread id.
-- Add to buffer using thread id.
TotalStoreOrder -> do TotalStoreOrder -> do
wb' <- bufferWrite wb (tid, Nothing) cref a wb' <- bufferWrite (cWriteBuf ctx) (tid, Nothing) cref a
return $ Right (goto c tid threads, idSource, WriteRef crid, wb', caps) pure $ Right (ctx { cThreads = goto c tid (cThreads ctx), cWriteBuf = wb' }, Right (WriteRef crid))
-- add to buffer using both thread id and cref id
-- Add to buffer using both thread id and cref id
PartialStoreOrder -> do PartialStoreOrder -> do
wb' <- bufferWrite wb (tid, Just crid) cref a wb' <- bufferWrite (cWriteBuf ctx) (tid, Just crid) cref a
return $ Right (goto c tid threads, idSource, WriteRef crid, wb', caps) pure $ Right (ctx { cThreads = goto c tid (cThreads ctx), cWriteBuf = wb' }, Right (WriteRef crid))
-- | Perform a compare-and-swap on a @CRef@. -- perform a compare-and-swap on a @CRef@.
stepCasRef cref@(CRef crid _) tick a c = synchronised $ do ACasRef cref@(CRef crid _) tick a c -> synchronised $ do
(suc, tick') <- casCRef cref tid tick a (suc, tick') <- casCRef cref tid tick a
simple (goto (c (suc, tick')) tid threads) $ CasRef crid suc simple (goto (c (suc, tick')) tid (cThreads ctx)) $ CasRef crid suc
-- | Commit a @CRef@ write -- commit a @CRef@ write
stepCommit t c = do ACommit t c -> do
wb' <- case memtype of wb' <- case memtype of
-- Shouldn't ever get here -- shouldn't ever get here
SequentialConsistency -> SequentialConsistency ->
error "Attempting to commit under SequentialConsistency" error "Attempting to commit under SequentialConsistency"
-- commit using the thread id.
TotalStoreOrder -> commitWrite (cWriteBuf ctx) (t, Nothing)
-- commit using the cref id.
PartialStoreOrder -> commitWrite (cWriteBuf ctx) (t, Just c)
pure $ Right (ctx { cWriteBuf = wb' }, Right (CommitRef t c))
-- Commit using the thread id. -- run a STM transaction atomically.
TotalStoreOrder -> commitWrite wb (t, Nothing) AAtom stm c -> synchronised $ do
(res, idSource', trace) <- runTransaction stm (cIdSource ctx)
-- Commit using the cref id.
PartialStoreOrder -> commitWrite wb (t, Just c)
return $ Right (threads, idSource, CommitRef t c, wb', caps)
-- | Run a STM transaction atomically.
stepAtom stm c = synchronised $ do
(res, idSource', trace) <- runstm stm idSource
case res of case res of
Success _ written val -> Success _ written val ->
let (threads', woken) = wake (OnTVar written) threads let (threads', woken) = wake (OnTVar written) (cThreads ctx)
in return $ Right (goto (c val) tid threads', idSource', STM trace woken, wb, caps) in pure $ Right (ctx { cThreads = goto (c val) tid threads', cIdSource = idSource' }, Right (STM trace woken))
Retry touched -> Retry touched ->
let threads' = block (OnTVar touched) tid threads let threads' = block (OnTVar touched) tid (cThreads ctx)
in return $ Right (threads', idSource', BlockedSTM trace, wb, caps) in pure $ Right (ctx { cThreads = threads', cIdSource = idSource'}, Right (BlockedSTM trace))
Exception e -> do Exception e -> do
res' <- stepThrow e res' <- stepThrow e
return $ case res' of pure $ case res' of
Right (threads', _, _, _, _) -> Right (threads', idSource', Throw, wb, caps) Right (ctx', _) -> Right (ctx' { cIdSource = idSource' }, Right Throw)
Left err -> Left err Left err -> Left err
-- | Run a subcomputation in an exception-catching context. -- lift an action from the underlying monad into the @Conc@
stepCatching h ma c = simple threads' Catching where -- computation.
a = runCont ma (APopCatching . c) ALift na -> do
e exc = runCont (h exc) (APopCatching . c) a <- na
simple (goto a tid (cThreads ctx)) LiftIO
threads' = goto a tid (catching e tid threads) -- throw an exception, and propagate it to the appropriate
-- | Pop the top exception handler from the thread's stack.
stepPopCatching a = simple threads' PopCatching where
threads' = goto a tid (uncatching tid threads)
-- | Throw an exception, and propagate it to the appropriate
-- handler. -- handler.
stepThrow e = AThrow e -> stepThrow e
case propagate (toException e) tid threads of
Just threads' -> simple threads' Throw
Nothing -> return $ Left UncaughtException
-- | Throw an exception to the target thread, and propagate it to -- throw an exception to the target thread, and propagate it to
-- the appropriate handler. -- the appropriate handler.
stepThrowTo t e c = synchronised $ AThrowTo t e c -> synchronised $
let threads' = goto c tid threads let threads' = goto c tid (cThreads ctx)
blocked = block (OnMask t) tid threads blocked = block (OnMask t) tid (cThreads ctx)
in case M.lookup t threads of in case M.lookup t (cThreads ctx) of
Just thread Just thread
| interruptible thread -> case propagate (toException e) t threads' of | interruptible thread -> case propagate (toException e) t threads' of
Just threads'' -> simple threads'' $ ThrowTo t Just threads'' -> simple threads'' $ ThrowTo t
Nothing Nothing
| t == initialThread -> return $ Left UncaughtException | t == initialThread -> pure $ Left UncaughtException
| otherwise -> simple (kill t threads') $ ThrowTo t | otherwise -> simple (kill t threads') $ ThrowTo t
| otherwise -> simple blocked $ BlockedThrowTo t | otherwise -> simple blocked $ BlockedThrowTo t
Nothing -> simple threads' $ ThrowTo t Nothing -> simple threads' $ ThrowTo t
-- | Execute a subcomputation with a new masking state, and give -- run a subcomputation in an exception-catching context.
-- it a function to run a computation with the current masking ACatching h ma c ->
-- state. let a = runCont ma (APopCatching . c)
-- e exc = runCont (h exc) (APopCatching . c)
-- Explicit type sig necessary for checking in the prescence of threads' = goto a tid (catching e tid (cThreads ctx))
-- 'umask', sadly. in simple threads' Catching
stepMasking :: MaskingState
-> ((forall b. M n r s b -> M n r s b) -> M n r s a)
-> (a -> Action n r s)
-> n (Either Failure (Threads n r s, IdSource, ThreadAction, WriteBuffer r, Int))
stepMasking m ma c = simple threads' $ SetMasking False m where
a = runCont (ma umask) (AResetMask False False m' . c)
m' = _masking . fromJust $ M.lookup tid threads -- pop the top exception handler from the thread's stack.
umask mb = resetMask True m' >> mb >>= \b -> resetMask False m >> return b APopCatching a ->
let threads' = goto a tid (uncatching tid (cThreads ctx))
in simple threads' PopCatching
-- execute a subcomputation with a new masking state, and give it
-- a function to run a computation with the current masking state.
AMasking m ma c ->
let a = runCont (ma umask) (AResetMask False False m' . c)
m' = _masking . fromJust $ M.lookup tid (cThreads ctx)
umask mb = resetMask True m' >> mb >>= \b -> resetMask False m >> pure b
resetMask typ ms = cont $ \k -> AResetMask typ True ms $ k () resetMask typ ms = cont $ \k -> AResetMask typ True ms $ k ()
threads' = goto a tid (mask m tid (cThreads ctx))
in simple threads' $ SetMasking False m
threads' = goto a tid (mask m tid threads)
-- | Reset the masking thread of the state. -- reset the masking thread of the state.
stepResetMask b1 b2 m c = simple threads' act where AResetMask b1 b2 m c ->
act = (if b1 then SetMasking else ResetMasking) b2 m let act = (if b1 then SetMasking else ResetMasking) b2 m
threads' = goto c tid (mask m tid threads) threads' = goto c tid (mask m tid (cThreads ctx))
in simple threads' act
-- | Create a new @MVar@, using the next 'MVarId'. -- execute a 'return' or 'pure'.
stepNewVar n c = do AReturn c -> simple (goto c tid (cThreads ctx)) Return
let (idSource', newmvid) = nextMVId n idSource
ref <- newRef Nothing
let mvar = MVar newmvid ref
return $ Right (goto (c mvar) tid threads, idSource', NewVar newmvid, wb, caps)
-- | Create a new @CRef@, using the next 'CRefId'. -- add a message to the trace.
stepNewRef n a c = do AMessage m c -> simple (goto c tid (cThreads ctx)) (Message m)
let (idSource', newcrid) = nextCRId n idSource
ref <- newRef (M.empty, 0, a)
let cref = CRef newcrid ref
return $ Right (goto (c cref) tid threads, idSource', NewRef newcrid, wb, caps)
-- | Lift an action from the underlying monad into the @Conc@ -- kill the current thread.
-- computation. AStop na -> na >> simple (kill tid (cThreads ctx)) Stop
stepLift na = do
a <- na
simple (goto a tid threads) LiftIO
-- | Execute a 'return' or 'pure'. -- run a subconcurrent computation.
stepReturn c = simple (goto c tid threads) Return ASub ma c
| M.size (cThreads ctx) > 1 -> pure (Left IllegalSubconcurrency)
| otherwise -> do
(res, g', trace) <- runConcurrency sched memtype (cSchedState ctx) ma
pure $ Right (ctx { cThreads = goto (c res) tid (cThreads ctx), cSchedState = g' }, Left (Subconcurrency, trace))
where
-- | Add a message to the trace. -- this is not inline in the long @case@ above as it's needed by
stepMessage m c = simple (goto c tid threads) (Message m) -- @AAtom@, @AThrow@, and @AThrowTo@.
stepThrow e =
case propagate (toException e) tid (cThreads ctx) of
Just threads' -> simple threads' Throw
Nothing -> pure $ Left UncaughtException
-- | Kill the current thread. -- helper for actions which only change the threads.
stepStop na = na >> simple (kill tid threads) Stop simple threads' act = pure $ Right (ctx { cThreads = threads' }, Right act)
-- | Helper for actions which don't touch the 'IdSource' or -- helper for actions impose a write barrier.
-- 'WriteBuffer'
simple threads' act = return $ Right (threads', idSource, act, wb, caps)
-- | Helper for actions impose a write barrier.
synchronised ma = do synchronised ma = do
writeBarrier wb writeBarrier (cWriteBuf ctx)
res <- ma res <- ma
return $ case res of return $ case res of
Right (threads', idSource', act', _, caps') -> Right (threads', idSource', act', emptyBuffer, caps') Right (ctx', act) -> Right (ctx' { cWriteBuf = emptyBuffer }, act)
_ -> res _ -> res

View File

@ -18,6 +18,7 @@ import Data.Dynamic (Dynamic)
import Data.Map.Strict (Map) import Data.Map.Strict (Map)
import Data.List.NonEmpty (NonEmpty, fromList) import Data.List.NonEmpty (NonEmpty, fromList)
import Test.DejaFu.Common import Test.DejaFu.Common
import Test.DejaFu.STM (STMLike)
{-# ANN module ("HLint: ignore Use record patterns" :: String) #-} {-# ANN module ("HLint: ignore Use record patterns" :: String) #-}
@ -32,16 +33,16 @@ import Test.DejaFu.Common
-- current expression of threads and exception handlers very difficult -- current expression of threads and exception handlers very difficult
-- (perhaps even not possible without significant reworking), so I -- (perhaps even not possible without significant reworking), so I
-- abandoned the attempt. -- abandoned the attempt.
newtype M n r s a = M { runM :: (a -> Action n r s) -> Action n r s } newtype M n r a = M { runM :: (a -> Action n r) -> Action n r }
instance Functor (M n r s) where instance Functor (M n r) where
fmap f m = M $ \ c -> runM m (c . f) fmap f m = M $ \ c -> runM m (c . f)
instance Applicative (M n r s) where instance Applicative (M n r) where
pure x = M $ \c -> AReturn $ c x pure x = M $ \c -> AReturn $ c x
f <*> v = M $ \c -> runM f (\g -> runM v (c . g)) f <*> v = M $ \c -> runM f (\g -> runM v (c . g))
instance Monad (M n r s) where instance Monad (M n r) where
return = pure return = pure
m >>= k = M $ \c -> runM m (\x -> runM (k x) c) m >>= k = M $ \c -> runM m (\x -> runM (k x) c)
@ -82,11 +83,11 @@ data Ticket a = Ticket
} }
-- | Construct a continuation-passing operation from a function. -- | Construct a continuation-passing operation from a function.
cont :: ((a -> Action n r s) -> Action n r s) -> M n r s a cont :: ((a -> Action n r) -> Action n r) -> M n r a
cont = M cont = M
-- | Run a CPS computation with the given final computation. -- | Run a CPS computation with the given final computation.
runCont :: M n r s a -> (a -> Action n r s) -> Action n r s runCont :: M n r a -> (a -> Action n r) -> Action n r
runCont = runM runCont = runM
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
@ -96,49 +97,51 @@ runCont = runM
-- only occur as a result of an action, and they cover (most of) the -- only occur as a result of an action, and they cover (most of) the
-- primitives of the concurrency. 'spawn' is absent as it is -- primitives of the concurrency. 'spawn' is absent as it is
-- implemented in terms of 'newEmptyMVar', 'fork', and 'putMVar'. -- implemented in terms of 'newEmptyMVar', 'fork', and 'putMVar'.
data Action n r s = data Action n r =
AFork String ((forall b. M n r s b -> M n r s b) -> Action n r s) (ThreadId -> Action n r s) AFork String ((forall b. M n r b -> M n r b) -> Action n r) (ThreadId -> Action n r)
| AMyTId (ThreadId -> Action n r s) | AMyTId (ThreadId -> Action n r)
| AGetNumCapabilities (Int -> Action n r s) | AGetNumCapabilities (Int -> Action n r)
| ASetNumCapabilities Int (Action n r s) | ASetNumCapabilities Int (Action n r)
| forall a. ANewVar String (MVar r a -> Action n r s) | forall a. ANewVar String (MVar r a -> Action n r)
| forall a. APutVar (MVar r a) a (Action n r s) | forall a. APutVar (MVar r a) a (Action n r)
| forall a. ATryPutVar (MVar r a) a (Bool -> Action n r s) | forall a. ATryPutVar (MVar r a) a (Bool -> Action n r)
| forall a. AReadVar (MVar r a) (a -> Action n r s) | forall a. AReadVar (MVar r a) (a -> Action n r)
| forall a. ATakeVar (MVar r a) (a -> Action n r s) | forall a. ATakeVar (MVar r a) (a -> Action n r)
| forall a. ATryTakeVar (MVar r a) (Maybe a -> Action n r s) | forall a. ATryTakeVar (MVar r a) (Maybe a -> Action n r)
| forall a. ANewRef String a (CRef r a -> Action n r s) | forall a. ANewRef String a (CRef r a -> Action n r)
| forall a. AReadRef (CRef r a) (a -> Action n r s) | forall a. AReadRef (CRef r a) (a -> Action n r)
| forall a. AReadRefCas (CRef r a) (Ticket a -> Action n r s) | forall a. AReadRefCas (CRef r a) (Ticket a -> Action n r)
| forall a b. AModRef (CRef r a) (a -> (a, b)) (b -> Action n r s) | forall a b. AModRef (CRef r a) (a -> (a, b)) (b -> Action n r)
| forall a b. AModRefCas (CRef r a) (a -> (a, b)) (b -> Action n r s) | forall a b. AModRefCas (CRef r a) (a -> (a, b)) (b -> Action n r)
| forall a. AWriteRef (CRef r a) a (Action n r s) | forall a. AWriteRef (CRef r a) a (Action n r)
| forall a. ACasRef (CRef r a) (Ticket a) a ((Bool, Ticket a) -> Action n r s) | forall a. ACasRef (CRef r a) (Ticket a) a ((Bool, Ticket a) -> Action n r)
| forall e. Exception e => AThrow e | forall e. Exception e => AThrow e
| forall e. Exception e => AThrowTo ThreadId e (Action n r s) | forall e. Exception e => AThrowTo ThreadId e (Action n r)
| forall a e. Exception e => ACatching (e -> M n r s a) (M n r s a) (a -> Action n r s) | forall a e. Exception e => ACatching (e -> M n r a) (M n r a) (a -> Action n r)
| APopCatching (Action n r s) | APopCatching (Action n r)
| forall a. AMasking MaskingState ((forall b. M n r s b -> M n r s b) -> M n r s a) (a -> Action n r s) | forall a. AMasking MaskingState ((forall b. M n r b -> M n r b) -> M n r a) (a -> Action n r)
| AResetMask Bool Bool MaskingState (Action n r s) | AResetMask Bool Bool MaskingState (Action n r)
| AMessage Dynamic (Action n r s) | AMessage Dynamic (Action n r)
| forall a. AAtom (s a) (a -> Action n r s) | forall a. AAtom (STMLike n r a) (a -> Action n r)
| ALift (n (Action n r s)) | ALift (n (Action n r))
| AYield (Action n r s) | AYield (Action n r)
| AReturn (Action n r s) | AReturn (Action n r)
| ACommit ThreadId CRefId | ACommit ThreadId CRefId
| AStop (n ()) | AStop (n ())
| forall a. ASub (M n r a) (Either Failure a -> Action n r)
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- * Scheduling & Traces -- * Scheduling & Traces
-- | Look as far ahead in the given continuation as possible. -- | Look as far ahead in the given continuation as possible.
lookahead :: Action n r s -> NonEmpty Lookahead lookahead :: Action n r -> NonEmpty Lookahead
lookahead = fromList . lookahead' where lookahead = fromList . lookahead' where
lookahead' (AFork _ _ _) = [WillFork] lookahead' (AFork _ _ _) = [WillFork]
lookahead' (AMyTId _) = [WillMyThreadId] lookahead' (AMyTId _) = [WillMyThreadId]
@ -170,3 +173,4 @@ lookahead = fromList . lookahead' where
lookahead' (AYield k) = WillYield : lookahead' k lookahead' (AYield k) = WillYield : lookahead' k
lookahead' (AReturn k) = WillReturn : lookahead' k lookahead' (AReturn k) = WillReturn : lookahead' k
lookahead' (AStop _) = [WillStop] lookahead' (AStop _) = [WillStop]
lookahead' (ASub _ _) = [WillSubconcurrency]

View File

@ -126,7 +126,7 @@ writeBarrier (WriteBuffer wb) = mapM_ flush $ M.elems wb where
flush = mapM_ $ \(BufferedWrite _ cref a) -> writeImmediate cref a flush = mapM_ $ \(BufferedWrite _ cref a) -> writeImmediate cref a
-- | Add phantom threads to the thread list to commit pending writes. -- | Add phantom threads to the thread list to commit pending writes.
addCommitThreads :: WriteBuffer r -> Threads n r s -> Threads n r s addCommitThreads :: WriteBuffer r -> Threads n r -> Threads n r
addCommitThreads (WriteBuffer wb) ts = ts <> M.fromList phantoms where addCommitThreads (WriteBuffer wb) ts = ts <> M.fromList phantoms where
phantoms = [ (ThreadId Nothing $ negate tid, mkthread $ fromJust c) phantoms = [ (ThreadId Nothing $ negate tid, mkthread $ fromJust c)
| ((k, b), tid) <- zip (M.toList wb) [1..] | ((k, b), tid) <- zip (M.toList wb) [1..]
@ -136,41 +136,41 @@ addCommitThreads (WriteBuffer wb) ts = ts <> M.fromList phantoms where
go EmptyL = Nothing go EmptyL = Nothing
-- | Remove phantom threads. -- | Remove phantom threads.
delCommitThreads :: Threads n r s -> Threads n r s delCommitThreads :: Threads n r -> Threads n r
delCommitThreads = M.filterWithKey $ \k _ -> k >= initialThread delCommitThreads = M.filterWithKey $ \k _ -> k >= initialThread
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- * Manipulating @MVar@s -- * Manipulating @MVar@s
-- | Put into a @MVar@, blocking if full. -- | Put into a @MVar@, blocking if full.
putIntoMVar :: MonadRef r n => MVar r a -> a -> Action n r s putIntoMVar :: MonadRef r n => MVar r a -> a -> Action n r
-> ThreadId -> Threads n r s -> n (Bool, Threads n r s, [ThreadId]) -> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
putIntoMVar cvar a c = mutMVar True cvar a (const c) putIntoMVar cvar a c = mutMVar True cvar a (const c)
-- | Try to put into a @MVar@, not blocking if full. -- | Try to put into a @MVar@, not blocking if full.
tryPutIntoMVar :: MonadRef r n => MVar r a -> a -> (Bool -> Action n r s) tryPutIntoMVar :: MonadRef r n => MVar r a -> a -> (Bool -> Action n r)
-> ThreadId -> Threads n r s -> n (Bool, Threads n r s, [ThreadId]) -> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
tryPutIntoMVar = mutMVar False tryPutIntoMVar = mutMVar False
-- | Read from a @MVar@, blocking if empty. -- | Read from a @MVar@, blocking if empty.
readFromMVar :: MonadRef r n => MVar r a -> (a -> Action n r s) readFromMVar :: MonadRef r n => MVar r a -> (a -> Action n r)
-> ThreadId -> Threads n r s -> n (Bool, Threads n r s, [ThreadId]) -> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
readFromMVar cvar c = seeMVar False True cvar (c . fromJust) readFromMVar cvar c = seeMVar False True cvar (c . fromJust)
-- | Take from a @MVar@, blocking if empty. -- | Take from a @MVar@, blocking if empty.
takeFromMVar :: MonadRef r n => MVar r a -> (a -> Action n r s) takeFromMVar :: MonadRef r n => MVar r a -> (a -> Action n r)
-> ThreadId -> Threads n r s -> n (Bool, Threads n r s, [ThreadId]) -> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
takeFromMVar cvar c = seeMVar True True cvar (c . fromJust) takeFromMVar cvar c = seeMVar True True cvar (c . fromJust)
-- | Try to take from a @MVar@, not blocking if empty. -- | Try to take from a @MVar@, not blocking if empty.
tryTakeFromMVar :: MonadRef r n => MVar r a -> (Maybe a -> Action n r s) tryTakeFromMVar :: MonadRef r n => MVar r a -> (Maybe a -> Action n r)
-> ThreadId -> Threads n r s -> n (Bool, Threads n r s, [ThreadId]) -> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
tryTakeFromMVar = seeMVar True False tryTakeFromMVar = seeMVar True False
-- | Mutate a @MVar@, in either a blocking or nonblocking way. -- | Mutate a @MVar@, in either a blocking or nonblocking way.
mutMVar :: MonadRef r n mutMVar :: MonadRef r n
=> Bool -> MVar r a -> a -> (Bool -> Action n r s) => Bool -> MVar r a -> a -> (Bool -> Action n r)
-> ThreadId -> Threads n r s -> n (Bool, Threads n r s, [ThreadId]) -> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
mutMVar blocking (MVar cvid ref) a c threadid threads = do mutMVar blocking (MVar cvid ref) a c threadid threads = do
val <- readRef ref val <- readRef ref
@ -191,8 +191,8 @@ mutMVar blocking (MVar cvid ref) a c threadid threads = do
-- | Read a @MVar@, in either a blocking or nonblocking -- | Read a @MVar@, in either a blocking or nonblocking
-- way. -- way.
seeMVar :: MonadRef r n seeMVar :: MonadRef r n
=> Bool -> Bool -> MVar r a -> (Maybe a -> Action n r s) => Bool -> Bool -> MVar r a -> (Maybe a -> Action n r)
-> ThreadId -> Threads n r s -> n (Bool, Threads n r s, [ThreadId]) -> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
seeMVar emptying blocking (MVar cvid ref) c threadid threads = do seeMVar emptying blocking (MVar cvid ref) c threadid threads = do
val <- readRef ref val <- readRef ref

View File

@ -27,22 +27,22 @@ import qualified Data.Map.Strict as M
-- * Threads -- * Threads
-- | Threads are stored in a map index by 'ThreadId'. -- | Threads are stored in a map index by 'ThreadId'.
type Threads n r s = Map ThreadId (Thread n r s) type Threads n r = Map ThreadId (Thread n r)
-- | All the state of a thread. -- | All the state of a thread.
data Thread n r s = Thread data Thread n r = Thread
{ _continuation :: Action n r s { _continuation :: Action n r
-- ^ The next action to execute. -- ^ The next action to execute.
, _blocking :: Maybe BlockedOn , _blocking :: Maybe BlockedOn
-- ^ The state of any blocks. -- ^ The state of any blocks.
, _handlers :: [Handler n r s] , _handlers :: [Handler n r]
-- ^ Stack of exception handlers -- ^ Stack of exception handlers
, _masking :: MaskingState , _masking :: MaskingState
-- ^ The exception masking state. -- ^ The exception masking state.
} }
-- | Construct a thread with just one action -- | Construct a thread with just one action
mkthread :: Action n r s -> Thread n r s mkthread :: Action n r -> Thread n r
mkthread c = Thread c Nothing [] Unmasked mkthread c = Thread c Nothing [] Unmasked
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
@ -53,7 +53,7 @@ mkthread c = Thread c Nothing [] Unmasked
data BlockedOn = OnMVarFull MVarId | OnMVarEmpty MVarId | OnTVar [TVarId] | OnMask ThreadId deriving Eq data BlockedOn = OnMVarFull MVarId | OnMVarEmpty MVarId | OnTVar [TVarId] | OnMask ThreadId deriving Eq
-- | Determine if a thread is blocked in a certain way. -- | Determine if a thread is blocked in a certain way.
(~=) :: Thread n r s -> BlockedOn -> Bool (~=) :: Thread n r -> BlockedOn -> Bool
thread ~= theblock = case (_blocking thread, theblock) of thread ~= theblock = case (_blocking thread, theblock) of
(Just (OnMVarFull _), OnMVarFull _) -> True (Just (OnMVarFull _), OnMVarFull _) -> True
(Just (OnMVarEmpty _), OnMVarEmpty _) -> True (Just (OnMVarEmpty _), OnMVarEmpty _) -> True
@ -65,11 +65,11 @@ thread ~= theblock = case (_blocking thread, theblock) of
-- * Exceptions -- * Exceptions
-- | An exception handler. -- | An exception handler.
data Handler n r s = forall e. Exception e => Handler (e -> Action n r s) data Handler n r = forall e. Exception e => Handler (e -> Action n r)
-- | Propagate an exception upwards, finding the closest handler -- | Propagate an exception upwards, finding the closest handler
-- which can deal with it. -- which can deal with it.
propagate :: SomeException -> ThreadId -> Threads n r s -> Maybe (Threads n r s) propagate :: SomeException -> ThreadId -> Threads n r -> Maybe (Threads n r)
propagate e tid threads = case M.lookup tid threads >>= go . _handlers of propagate e tid threads = case M.lookup tid threads >>= go . _handlers of
Just (act, hs) -> Just $ except act hs tid threads Just (act, hs) -> Just $ except act hs tid threads
Nothing -> Nothing Nothing -> Nothing
@ -79,40 +79,40 @@ propagate e tid threads = case M.lookup tid threads >>= go . _handlers of
go (Handler h:hs) = maybe (go hs) (\act -> Just (act, hs)) $ h <$> fromException e go (Handler h:hs) = maybe (go hs) (\act -> Just (act, hs)) $ h <$> fromException e
-- | Check if a thread can be interrupted by an exception. -- | Check if a thread can be interrupted by an exception.
interruptible :: Thread n r s -> Bool interruptible :: Thread n r -> Bool
interruptible thread = _masking thread == Unmasked || (_masking thread == MaskedInterruptible && isJust (_blocking thread)) interruptible thread = _masking thread == Unmasked || (_masking thread == MaskedInterruptible && isJust (_blocking thread))
-- | Register a new exception handler. -- | Register a new exception handler.
catching :: Exception e => (e -> Action n r s) -> ThreadId -> Threads n r s -> Threads n r s catching :: Exception e => (e -> Action n r) -> ThreadId -> Threads n r -> Threads n r
catching h = M.alter $ \(Just thread) -> Just $ thread { _handlers = Handler h : _handlers thread } catching h = M.alter $ \(Just thread) -> Just $ thread { _handlers = Handler h : _handlers thread }
-- | Remove the most recent exception handler. -- | Remove the most recent exception handler.
uncatching :: ThreadId -> Threads n r s -> Threads n r s uncatching :: ThreadId -> Threads n r -> Threads n r
uncatching = M.alter $ \(Just thread) -> Just $ thread { _handlers = tail $ _handlers thread } uncatching = M.alter $ \(Just thread) -> Just $ thread { _handlers = tail $ _handlers thread }
-- | Raise an exception in a thread. -- | Raise an exception in a thread.
except :: Action n r s -> [Handler n r s] -> ThreadId -> Threads n r s -> Threads n r s except :: Action n r -> [Handler n r] -> ThreadId -> Threads n r -> Threads n r
except act hs = M.alter $ \(Just thread) -> Just $ thread { _continuation = act, _handlers = hs, _blocking = Nothing } except act hs = M.alter $ \(Just thread) -> Just $ thread { _continuation = act, _handlers = hs, _blocking = Nothing }
-- | Set the masking state of a thread. -- | Set the masking state of a thread.
mask :: MaskingState -> ThreadId -> Threads n r s -> Threads n r s mask :: MaskingState -> ThreadId -> Threads n r -> Threads n r
mask ms = M.alter $ \(Just thread) -> Just $ thread { _masking = ms } mask ms = M.alter $ \(Just thread) -> Just $ thread { _masking = ms }
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- * Manipulating threads -- * Manipulating threads
-- | Replace the @Action@ of a thread. -- | Replace the @Action@ of a thread.
goto :: Action n r s -> ThreadId -> Threads n r s -> Threads n r s goto :: Action n r -> ThreadId -> Threads n r -> Threads n r
goto a = M.alter $ \(Just thread) -> Just (thread { _continuation = a }) goto a = M.alter $ \(Just thread) -> Just (thread { _continuation = a })
-- | Start a thread with the given ID, inheriting the masking state -- | Start a thread with the given ID, inheriting the masking state
-- from the parent thread. This ID must not already be in use! -- from the parent thread. This ID must not already be in use!
launch :: ThreadId -> ThreadId -> ((forall b. M n r s b -> M n r s b) -> Action n r s) -> Threads n r s -> Threads n r s launch :: ThreadId -> ThreadId -> ((forall b. M n r b -> M n r b) -> Action n r) -> Threads n r -> Threads n r
launch parent tid a threads = launch' ms tid a threads where launch parent tid a threads = launch' ms tid a threads where
ms = fromMaybe Unmasked $ _masking <$> M.lookup parent threads ms = fromMaybe Unmasked $ _masking <$> M.lookup parent threads
-- | Start a thread with the given ID and masking state. This must not already be in use! -- | Start a thread with the given ID and masking state. This must not already be in use!
launch' :: MaskingState -> ThreadId -> ((forall b. M n r s b -> M n r s b) -> Action n r s) -> Threads n r s -> Threads n r s launch' :: MaskingState -> ThreadId -> ((forall b. M n r b -> M n r b) -> Action n r) -> Threads n r -> Threads n r
launch' ms tid a = M.insert tid thread where launch' ms tid a = M.insert tid thread where
thread = Thread { _continuation = a umask, _blocking = Nothing, _handlers = [], _masking = ms } thread = Thread { _continuation = a umask, _blocking = Nothing, _handlers = [], _masking = ms }
@ -120,11 +120,11 @@ launch' ms tid a = M.insert tid thread where
resetMask typ m = cont $ \k -> AResetMask typ True m $ k () resetMask typ m = cont $ \k -> AResetMask typ True m $ k ()
-- | Kill a thread. -- | Kill a thread.
kill :: ThreadId -> Threads n r s -> Threads n r s kill :: ThreadId -> Threads n r -> Threads n r
kill = M.delete kill = M.delete
-- | Block a thread. -- | Block a thread.
block :: BlockedOn -> ThreadId -> Threads n r s -> Threads n r s block :: BlockedOn -> ThreadId -> Threads n r -> Threads n r
block blockedOn = M.alter doBlock where block blockedOn = M.alter doBlock where
doBlock (Just thread) = Just $ thread { _blocking = Just blockedOn } doBlock (Just thread) = Just $ thread { _blocking = Just blockedOn }
doBlock _ = error "Invariant failure in 'block': thread does NOT exist!" doBlock _ = error "Invariant failure in 'block': thread does NOT exist!"
@ -132,7 +132,7 @@ block blockedOn = M.alter doBlock where
-- | Unblock all threads waiting on the appropriate block. For 'TVar' -- | Unblock all threads waiting on the appropriate block. For 'TVar'
-- blocks, this will wake all threads waiting on at least one of the -- blocks, this will wake all threads waiting on at least one of the
-- given 'TVar's. -- given 'TVar's.
wake :: BlockedOn -> Threads n r s -> (Threads n r s, [ThreadId]) wake :: BlockedOn -> Threads n r -> (Threads n r, [ThreadId])
wake blockedOn threads = (unblock <$> threads, M.keys $ M.filter isBlocked threads) where wake blockedOn threads = (unblock <$> threads, M.keys $ M.filter isBlocked threads) where
unblock thread unblock thread
| isBlocked thread = thread { _blocking = Nothing } | isBlocked thread = thread { _blocking = Nothing }

View File

@ -79,9 +79,8 @@ module Test.DejaFu.SCT
, sctLengthBound , sctLengthBound
) where ) where
import Control.DeepSeq (NFData(..))
import Control.Monad.Ref (MonadRef) import Control.Monad.Ref (MonadRef)
import Data.List (nub) import Data.List (foldl')
import qualified Data.Map.Strict as M import qualified Data.Map.Strict as M
import Data.Maybe (isJust, fromJust) import Data.Maybe (isJust, fromJust)
import qualified Data.Set as S import qualified Data.Set as S
@ -140,17 +139,16 @@ cBound (Bounds pb fb lb) =
-- --
-- If no bounds are enabled, just backtrack to the given point. -- If no bounds are enabled, just backtrack to the given point.
cBacktrack :: Bounds -> BacktrackFunc cBacktrack :: Bounds -> BacktrackFunc
cBacktrack (Bounds Nothing Nothing Nothing) bs i t = backtrackAt (const False) False bs i t cBacktrack (Bounds (Just _) _ _) = pBacktrack
cBacktrack (Bounds pb fb lb) bs i t = lBack . fBack $ pBack bs where cBacktrack (Bounds _ (Just _) _) = fBacktrack
pBack backs = if isJust pb then pBacktrack backs i t else backs cBacktrack (Bounds _ _ (Just _)) = lBacktrack
fBack backs = if isJust fb then fBacktrack backs i t else backs cBacktrack _ = backtrackAt (\_ _ -> False)
lBack backs = if isJust lb then lBacktrack backs i t else backs
------------------------------------------------------------------------------- -------------------------------------------------------------------------------
-- Pre-emption bounding -- Pre-emption bounding
newtype PreemptionBound = PreemptionBound Int newtype PreemptionBound = PreemptionBound Int
deriving (NFData, Enum, Eq, Ord, Num, Real, Integral, Read, Show) deriving (Enum, Eq, Ord, Num, Real, Integral, Read, Show)
-- | A sensible default preemption bound: 2. -- | A sensible default preemption bound: 2.
-- --
@ -181,31 +179,26 @@ pBound (PreemptionBound pb) ts dl = preEmpCount ts dl <= pb
-- the same state being reached multiple times, but is needed because -- the same state being reached multiple times, but is needed because
-- of the artificial dependency imposed by the bound. -- of the artificial dependency imposed by the bound.
pBacktrack :: BacktrackFunc pBacktrack :: BacktrackFunc
pBacktrack bs i tid = pBacktrack bs = backtrackAt (\_ _ -> False) bs . concatMap addConservative where
maybe id (\j' b -> backtrack True b j' tid) j $ backtrack False bs i tid addConservative o@(i, _, tid) = o : case conservative i of
Just j -> [(j, True, tid)]
Nothing -> []
where -- index of conservative point
-- Index of the conservative point conservative i = go (reverse (take (i-1) bs)) (i-1) where
j = goJ . reverse . pairs $ zip [0..i-1] bs where go _ (-1) = Nothing
goJ (((_,b1), (j',b2)):rest) go (b1:rest@(b2:_)) j
| bcktThreadid b1 /= bcktThreadid b2 | bcktThreadid b1 /= bcktThreadid b2
&& not (isCommitRef . snd $ bcktDecision b1) && not (isCommitRef $ bcktAction b1)
&& not (isCommitRef . snd $ bcktDecision b2) = Just j' && not (isCommitRef $ bcktAction b2) = Just j
| otherwise = goJ rest | otherwise = go rest (j-1)
goJ [] = Nothing go _ _ = Nothing
-- List of adjacent pairs
{-# INLINE pairs #-}
pairs = zip <*> tail
-- Add a backtracking point.
backtrack = backtrackAt $ const False
------------------------------------------------------------------------------- -------------------------------------------------------------------------------
-- Fair bounding -- Fair bounding
newtype FairBound = FairBound Int newtype FairBound = FairBound Int
deriving (NFData, Enum, Eq, Ord, Num, Real, Integral, Read, Show) deriving (Enum, Eq, Ord, Num, Real, Integral, Read, Show)
-- | A sensible default fair bound: 5. -- | A sensible default fair bound: 5.
-- --
@ -233,15 +226,15 @@ fBound (FairBound fb) ts (_, l) = maxYieldCountDiff ts l <= fb
-- | Add a backtrack point. If the thread isn't runnable, or performs -- | Add a backtrack point. If the thread isn't runnable, or performs
-- a release operation, add all runnable threads. -- a release operation, add all runnable threads.
fBacktrack :: BacktrackFunc fBacktrack :: BacktrackFunc
fBacktrack bs i t = backtrackAt check False bs i t where fBacktrack = backtrackAt check where
-- True if a release operation is performed. -- True if a release operation is performed.
check b = Just True == (willRelease <$> M.lookup t (bcktRunnable b)) check t b = Just True == (willRelease <$> M.lookup t (bcktRunnable b))
------------------------------------------------------------------------------- -------------------------------------------------------------------------------
-- Length bounding -- Length bounding
newtype LengthBound = LengthBound Int newtype LengthBound = LengthBound Int
deriving (NFData, Enum, Eq, Ord, Num, Real, Integral, Read, Show) deriving (Enum, Eq, Ord, Num, Real, Integral, Read, Show)
-- | A sensible default length bound: 250. -- | A sensible default length bound: 250.
-- --
@ -269,7 +262,7 @@ lBound (LengthBound lb) ts _ = length ts < lb
-- | Add a backtrack point. If the thread isn't runnable, add all -- | Add a backtrack point. If the thread isn't runnable, add all
-- runnable threads. -- runnable threads.
lBacktrack :: BacktrackFunc lBacktrack :: BacktrackFunc
lBacktrack = backtrackAt (const False) False lBacktrack = backtrackAt (\_ _ -> False)
------------------------------------------------------------------------------- -------------------------------------------------------------------------------
-- DPOR -- DPOR
@ -313,7 +306,7 @@ sctBounded memtype bf backtrack conc = go initialState where
if schedIgnore s if schedIgnore s
then go newDPOR then go newDPOR
else ((res, trace):) <$> go (pruneCommits $ addBacktracks bpoints newDPOR) else ((res, trace):) <$> go (addBacktracks bpoints newDPOR)
Nothing -> pure [] Nothing -> pure []
@ -332,32 +325,11 @@ sctBounded memtype bf backtrack conc = go initialState where
-- Incorporate the new backtracking steps into the DPOR tree. -- Incorporate the new backtracking steps into the DPOR tree.
addBacktracks = incorporateBacktrackSteps bf addBacktracks = incorporateBacktrackSteps bf
-------------------------------------------------------------------------------
-- Post-processing
-- | Remove commits from the todo sets where every other action will
-- result in a write barrier (and so a commit) occurring.
--
-- To get the benefit from this, do not execute commit actions from
-- the todo set until there are no other choises.
pruneCommits :: DPOR -> DPOR
pruneCommits bpor
| not onlycommits || not alldonesync = go bpor
| otherwise = go bpor { dporTodo = M.empty }
where
go b = b { dporDone = pruneCommits <$> dporDone bpor }
onlycommits = all (<initialThread) . M.keys $ dporTodo bpor
alldonesync = all barrier . M.elems $ dporDone bpor
barrier = isBarrier . simplifyAction . fromJust . dporAction
------------------------------------------------------------------------------- -------------------------------------------------------------------------------
-- Dependency function -- Dependency function
-- | Check if an action is dependent on another. -- | Check if an action is dependent on another.
dependent :: MemType -> DepState -> (ThreadId, ThreadAction) -> (ThreadId, ThreadAction) -> Bool dependent :: MemType -> DepState -> ThreadId -> ThreadAction -> ThreadId -> ThreadAction -> Bool
-- This is basically the same as 'dependent'', but can make use of the -- This is basically the same as 'dependent'', but can make use of the
-- additional information in a 'ThreadAction' to make different -- additional information in a 'ThreadAction' to make different
-- decisions in a few cases: -- decisions in a few cases:
@ -381,14 +353,14 @@ dependent :: MemType -> DepState -> (ThreadId, ThreadAction) -> (ThreadId, Threa
-- - Dependency of STM transactions can be /greatly/ improved here, -- - Dependency of STM transactions can be /greatly/ improved here,
-- as the 'Lookahead' does not know which @TVar@s will be touched, -- as the 'Lookahead' does not know which @TVar@s will be touched,
-- and so has to assume all transactions are dependent. -- and so has to assume all transactions are dependent.
dependent _ _ (_, SetNumCapabilities a) (_, GetNumCapabilities b) = a /= b dependent _ _ _ (SetNumCapabilities a) _ (GetNumCapabilities b) = a /= b
dependent _ ds (_, ThrowTo t) (t2, a) = t == t2 && canInterrupt ds t2 a dependent _ ds _ (ThrowTo t) t2 a = t == t2 && canInterrupt ds t2 a
dependent memtype ds (t1, a1) (t2, a2) = case rewind a2 of dependent memtype ds t1 a1 t2 a2 = case rewind a2 of
Just l2 Just l2
| isSTM a1 && isSTM a2 | isSTM a1 && isSTM a2
-> not . S.null $ tvarsOf a1 `S.intersection` tvarsOf a2 -> not . S.null $ tvarsOf a1 `S.intersection` tvarsOf a2
| not (isBlock a1 && isBarrier (simplifyLookahead l2)) -> | not (isBlock a1 && isBarrier (simplifyLookahead l2)) ->
dependent' memtype ds (t1, a1) (t2, l2) dependent' memtype ds t1 a1 t2 l2
_ -> dependentActions memtype ds (simplifyAction a1) (simplifyAction a2) _ -> dependentActions memtype ds (simplifyAction a1) (simplifyAction a2)
where where
@ -400,8 +372,8 @@ dependent memtype ds (t1, a1) (t2, a2) = case rewind a2 of
-- --
-- Termination of the initial thread is handled specially in the DPOR -- Termination of the initial thread is handled specially in the DPOR
-- implementation. -- implementation.
dependent' :: MemType -> DepState -> (ThreadId, ThreadAction) -> (ThreadId, Lookahead) -> Bool dependent' :: MemType -> DepState -> ThreadId -> ThreadAction -> ThreadId -> Lookahead -> Bool
dependent' memtype ds (t1, a1) (t2, l2) = case (a1, l2) of dependent' memtype ds t1 a1 t2 l2 = case (a1, l2) of
-- Worst-case assumption: all IO is dependent. -- Worst-case assumption: all IO is dependent.
(LiftIO, WillLiftIO) -> True (LiftIO, WillLiftIO) -> True
@ -496,6 +468,7 @@ yieldCount tid ts l = go initialThread ts where
| t == tid && willYield l = 1 | t == tid && willYield l = 1
| otherwise = 0 | otherwise = 0
{-# INLINE go' #-}
go' t t' act rest go' t t' act rest
| t == tid && didYield act = 1 + go t' rest | t == tid && didYield act = 1 + go t' rest
| otherwise = go t' rest | otherwise = go t' rest
@ -505,10 +478,14 @@ yieldCount tid ts l = go initialThread ts where
maxYieldCountDiff :: [(Decision, ThreadAction)] maxYieldCountDiff :: [(Decision, ThreadAction)]
-> Lookahead -> Lookahead
-> Int -> Int
maxYieldCountDiff ts l = maximum yieldCountDiffs where maxYieldCountDiff ts l = go 0 yieldCounts where
yieldsBy tid = yieldCount tid ts l go m (yc:ycs) =
yieldCounts = [yieldsBy tid | tid <- nub $ allTids ts] let m' = m `max` foldl' (go' yc) 0 ycs
yieldCountDiffs = [y1 - y2 | y1 <- yieldCounts, y2 <- yieldCounts] in go m' ycs
go m [] = m
go' yc0 m yc = m `max` abs (yc0 - yc)
yieldCounts = [yieldCount t ts l | t <- allTids ts]
-- All the threads created during the lifetime of the system. -- All the threads created during the lifetime of the system.
allTids ((_, act):rest) = allTids ((_, act):rest) =

View File

@ -11,14 +11,14 @@
-- interface of this library. -- interface of this library.
module Test.DejaFu.SCT.Internal where module Test.DejaFu.SCT.Internal where
import Control.DeepSeq (NFData(..), force)
import Control.Exception (MaskingState(..)) import Control.Exception (MaskingState(..))
import Data.Char (ord) import Data.Char (ord)
import Data.List (foldl', intercalate, partition, sortBy) import Data.Function (on)
import qualified Data.Foldable as F
import Data.List (intercalate, nubBy, partition, sortOn)
import Data.List.NonEmpty (NonEmpty(..), toList) import Data.List.NonEmpty (NonEmpty(..), toList)
import Data.Ord (Down(..), comparing)
import Data.Map.Strict (Map) import Data.Map.Strict (Map)
import Data.Maybe (fromJust, isJust, isNothing, mapMaybe) import Data.Maybe (catMaybes, fromJust, isNothing)
import qualified Data.Map.Strict as M import qualified Data.Map.Strict as M
import Data.Set (Set) import Data.Set (Set)
import qualified Data.Set as S import qualified Data.Set as S
@ -51,16 +51,7 @@ data DPOR = DPOR
, dporAction :: Maybe ThreadAction , dporAction :: Maybe ThreadAction
-- ^ What happened at this step. This will be 'Nothing' at the root, -- ^ What happened at this step. This will be 'Nothing' at the root,
-- 'Just' everywhere else. -- 'Just' everywhere else.
} } deriving Show
instance NFData DPOR where
rnf dpor = rnf ( dporRunnable dpor
, dporTodo dpor
, dporDone dpor
, dporSleep dpor
, dporTaken dpor
, dporAction dpor
)
-- | One step of the execution, including information for backtracking -- | One step of the execution, including information for backtracking
-- purposes. This backtracking information is used to generate new -- purposes. This backtracking information is used to generate new
@ -68,7 +59,9 @@ instance NFData DPOR where
data BacktrackStep = BacktrackStep data BacktrackStep = BacktrackStep
{ bcktThreadid :: ThreadId { bcktThreadid :: ThreadId
-- ^ The thread running at this step -- ^ The thread running at this step
, bcktDecision :: (Decision, ThreadAction) , bcktDecision :: Decision
-- ^ What was decided at this step.
, bcktAction :: ThreadAction
-- ^ What happened at this step. -- ^ What happened at this step.
, bcktRunnable :: Map ThreadId Lookahead , bcktRunnable :: Map ThreadId Lookahead
-- ^ The threads runnable at this step -- ^ The threads runnable at this step
@ -77,15 +70,7 @@ data BacktrackStep = BacktrackStep
-- alternatives were added conservatively due to the bound. -- alternatives were added conservatively due to the bound.
, bcktState :: DepState , bcktState :: DepState
-- ^ Some domain-specific state at this point. -- ^ Some domain-specific state at this point.
} } deriving Show
instance NFData BacktrackStep where
rnf b = rnf ( bcktThreadid b
, bcktDecision b
, bcktRunnable b
, bcktBacktracks b
, bcktState b
)
-- | Initial DPOR state, given an initial thread ID. This initial -- | Initial DPOR state, given an initial thread ID. This initial
-- thread should exist and be runnable at the start of execution. -- thread should exist and be runnable at the start of execution.
@ -127,24 +112,17 @@ findSchedulePrefix predicate idx dpor0
(ts, c, slp) = allPrefixes !! i (ts, c, slp) = allPrefixes !! i
in Just (ts, c, slp, g) in Just (ts, c, slp, g)
where where
allPrefixes = go (initialDPORThread dpor0) dpor0 allPrefixes = go dpor0
go tid dpor = go dpor =
-- All the possible prefix traces from this point, with -- All the possible prefix traces from this point, with
-- updated DPOR subtrees if taken from the done list. -- updated DPOR subtrees if taken from the done list.
let prefixes = concatMap go' (M.toList $ dporDone dpor) ++ here dpor let prefixes = here dpor : map go' (M.toList $ dporDone dpor)
-- Sort by number of preemptions, in descending order. in case concatPartition (\(t:_,_,_) -> predicate t) prefixes of
cmp = Down . preEmps tid dpor . (\(a,_,_) -> a)
sorted = sortBy (comparing cmp) prefixes
in if null prefixes
then []
else case partition (\(t:_,_,_) -> predicate t) sorted of
([], []) -> err "findSchedulePrefix" "empty prefix list!"
([], choices) -> choices ([], choices) -> choices
(choices, _) -> choices (choices, _) -> choices
go' (tid, dpor) = (\(ts,c,slp) -> (tid:ts,c,slp)) <$> go tid dpor go' (tid, dpor) = (\(ts,c,slp) -> (tid:ts,c,slp)) <$> go dpor
-- Prefix traces terminating with a to-do decision at this point. -- Prefix traces terminating with a to-do decision at this point.
here dpor = [([t], c, sleeps dpor) | (t, c) <- M.toList $ dporTodo dpor] here dpor = [([t], c, sleeps dpor) | (t, c) <- M.toList $ dporTodo dpor]
@ -154,16 +132,10 @@ findSchedulePrefix predicate idx dpor0
-- explored. -- explored.
sleeps dpor = dporSleep dpor `M.union` dporTaken dpor sleeps dpor = dporSleep dpor `M.union` dporTaken dpor
-- The number of pre-emptive context switches
preEmps tid dpor (t:ts) =
let rest = preEmps t (fromJust . M.lookup t $ dporDone dpor) ts
in if tid `S.member` dporRunnable dpor then 1 + rest else rest
preEmps _ _ [] = 0::Int
-- | Add a new trace to the tree, creating a new subtree branching off -- | Add a new trace to the tree, creating a new subtree branching off
-- at the point where the \"to-do\" decision was made. -- at the point where the \"to-do\" decision was made.
incorporateTrace incorporateTrace
:: (DepState -> (ThreadId, ThreadAction) -> (ThreadId, ThreadAction) -> Bool) :: (DepState -> ThreadId -> ThreadAction -> ThreadId -> ThreadAction -> Bool)
-- ^ Dependency function -- ^ Dependency function
-> Bool -> Bool
-- ^ Whether the \"to-do\" point which was used to create this new -- ^ Whether the \"to-do\" point which was used to create this new
@ -176,7 +148,7 @@ incorporateTrace
incorporateTrace dependency conservative trace dpor0 = grow initialDepState (initialDPORThread dpor0) trace dpor0 where incorporateTrace dependency conservative trace dpor0 = grow initialDepState (initialDPORThread dpor0) trace dpor0 where
grow state tid trc@((d, _, a):rest) dpor = grow state tid trc@((d, _, a):rest) dpor =
let tid' = tidOf tid d let tid' = tidOf tid d
state' = updateDepState state (tid', a) state' = updateDepState state tid' a
in case M.lookup tid' (dporDone dpor) of in case M.lookup tid' (dporDone dpor) of
Just dpor' -> Just dpor' ->
let done = M.insert tid' (grow state' tid' rest dpor') (dporDone dpor) let done = M.insert tid' (grow state' tid' rest dpor') (dporDone dpor)
@ -193,8 +165,8 @@ incorporateTrace dependency conservative trace dpor0 = grow initialDepState (ini
-- Construct a new subtree corresponding to a trace suffix. -- Construct a new subtree corresponding to a trace suffix.
subtree state tid sleep ((_, _, a):rest) = subtree state tid sleep ((_, _, a):rest) =
let state' = updateDepState state (tid, a) let state' = updateDepState state tid a
sleep' = M.filterWithKey (\t a' -> not $ dependency state' (tid, a) (t,a')) sleep sleep' = M.filterWithKey (\t a' -> not $ dependency state' tid a t a') sleep
in DPOR in DPOR
{ dporRunnable = S.fromList $ case rest of { dporRunnable = S.fromList $ case rest of
((_, runnable, _):_) -> map fst runnable ((_, runnable, _):_) -> map fst runnable
@ -225,7 +197,7 @@ incorporateTrace dependency conservative trace dpor0 = grow initialDepState (ini
-- runnable, a dependency is imposed between this final action and -- runnable, a dependency is imposed between this final action and
-- everything else. -- everything else.
findBacktrackSteps findBacktrackSteps
:: (DepState -> (ThreadId, ThreadAction) -> (ThreadId, Lookahead) -> Bool) :: (DepState -> ThreadId -> ThreadAction -> ThreadId -> Lookahead -> Bool)
-- ^ Dependency function. -- ^ Dependency function.
-> BacktrackFunc -> BacktrackFunc
-- ^ Backtracking function. Given a list of backtracking points, and -- ^ Backtracking function. Given a list of backtracking points, and
@ -244,17 +216,16 @@ findBacktrackSteps
-> Trace -> Trace
-- ^ The execution trace. -- ^ The execution trace.
-> [BacktrackStep] -> [BacktrackStep]
findBacktrackSteps _ _ _ bcktrck findBacktrackSteps dependency backtrack boundKill = go initialDepState S.empty initialThread [] . F.toList where
| Sq.null bcktrck = const []
findBacktrackSteps dependency backtrack boundKill bcktrck = go initialDepState S.empty initialThread [] (Sq.viewl bcktrck) where
-- Walk through the traces one step at a time, building up a list of -- Walk through the traces one step at a time, building up a list of
-- new backtracking points. -- new backtracking points.
go state allThreads tid bs ((e,i):<is) ((d,_,a):ts) = go state allThreads tid bs ((e,i):is) ((d,_,a):ts) =
let tid' = tidOf tid d let tid' = tidOf tid d
state' = updateDepState state (tid', a) state' = updateDepState state tid' a
this = BacktrackStep this = BacktrackStep
{ bcktThreadid = tid' { bcktThreadid = tid'
, bcktDecision = (d, a) , bcktDecision = d
, bcktAction = a
, bcktRunnable = M.fromList . toList $ e , bcktRunnable = M.fromList . toList $ e
, bcktBacktracks = M.fromList $ map (\i' -> (i', False)) i , bcktBacktracks = M.fromList $ map (\i' -> (i', False)) i
, bcktState = state' , bcktState = state'
@ -263,30 +234,41 @@ findBacktrackSteps dependency backtrack boundKill bcktrck = go initialDepState S
runnable = S.fromList (M.keys $ bcktRunnable this) runnable = S.fromList (M.keys $ bcktRunnable this)
allThreads' = allThreads `S.union` runnable allThreads' = allThreads `S.union` runnable
killsEarly = null ts && boundKill killsEarly = null ts && boundKill
in go state' allThreads' tid' bs' (Sq.viewl is) ts in go state' allThreads' tid' bs' is ts
go _ _ _ bs _ _ = bs go _ _ _ bs _ _ = bs
-- Find the prior actions dependent with this one and add -- Find the prior actions dependent with this one and add
-- backtracking points. -- backtracking points.
doBacktrack killsEarly allThreads enabledThreads bs = doBacktrack killsEarly allThreads enabledThreads bs =
let tagged = reverse $ zip [0..] bs let tagged = reverse $ zip [0..] bs
idxs = [ (head is, u) idxs = [ (head is, False, u)
| (u, n) <- enabledThreads | (u, n) <- enabledThreads
, v <- S.toList allThreads , v <- S.toList allThreads
, u /= v , u /= v
, let is = idxs' u n v tagged , let is = idxs' u n v tagged
, not $ null is] , not $ null is]
idxs' u n v = mapMaybe go' where idxs' u n v = catMaybes . go' True where
go' (i, b) {-# INLINE go' #-}
go' final ((i,b):rest)
-- Don't cross subconcurrency boundaries
| isSubC final b = []
-- If this is the final action in the trace and the -- If this is the final action in the trace and the
-- execution was killed due to nothing being within bounds -- execution was killed due to nothing being within bounds
-- (@killsEarly == True@) assume worst-case dependency. -- (@killsEarly == True@) assume worst-case dependency.
| bcktThreadid b == v && (killsEarly || isDependent b) = Just i | bcktThreadid b == v && (killsEarly || isDependent b) = Just i : go' False rest
| otherwise = Nothing | otherwise = go' False rest
go' _ [] = []
isDependent b = dependency (bcktState b) (bcktThreadid b, snd $ bcktDecision b) (u, n) {-# INLINE isSubC #-}
in foldl' (\b (i, u) -> backtrack b i u) bs idxs isSubC final b = case bcktAction b of
Stop -> not final && bcktThreadid b == initialThread
Subconcurrency -> bcktThreadid b == initialThread
_ -> False
{-# INLINE isDependent #-}
isDependent b = dependency (bcktState b) (bcktThreadid b) (bcktAction b) u n
in backtrack bs idxs
-- | Add new backtracking points, if they have not already been -- | Add new backtracking points, if they have not already been
-- visited, fit into the bound, and aren't in the sleep set. -- visited, fit into the bound, and aren't in the sleep set.
@ -302,10 +284,9 @@ incorporateBacktrackSteps bv = go Nothing [] where
go priorTid pref (b:bs) bpor = go priorTid pref (b:bs) bpor =
let bpor' = doBacktrack priorTid pref b bpor let bpor' = doBacktrack priorTid pref b bpor
tid = bcktThreadid b tid = bcktThreadid b
pref' = pref ++ [bcktDecision b] pref' = pref ++ [(bcktDecision b, bcktAction b)]
child = go (Just tid) pref' bs . fromJust $ M.lookup tid (dporDone bpor) child = go (Just tid) pref' bs . fromJust $ M.lookup tid (dporDone bpor)
in bpor' { dporDone = M.insert tid child $ dporDone bpor' } in bpor' { dporDone = M.insert tid child $ dporDone bpor' }
go _ _ [] bpor = bpor go _ _ [] bpor = bpor
doBacktrack priorTid pref b bpor = doBacktrack priorTid pref b bpor =
@ -343,16 +324,7 @@ data SchedState = SchedState
, schedDepState :: DepState , schedDepState :: DepState
-- ^ State used by the dependency function to determine when to -- ^ State used by the dependency function to determine when to
-- remove decisions from the sleep set. -- remove decisions from the sleep set.
} } deriving Show
instance NFData SchedState where
rnf s = rnf ( schedSleep s
, schedPrefix s
, schedBPoints s
, schedIgnore s
, schedBoundKill s
, schedDepState s
)
-- | Initial scheduler state for a given prefix -- | Initial scheduler state for a given prefix
initialSchedState :: Map ThreadId ThreadAction initialSchedState :: Map ThreadId ThreadAction
@ -378,52 +350,60 @@ type BoundFunc
-- | A backtracking step is a point in the execution where another -- | A backtracking step is a point in the execution where another
-- decision needs to be made, in order to explore interesting new -- decision needs to be made, in order to explore interesting new
-- schedules. A backtracking /function/ takes the steps identified so -- schedules. A backtracking /function/ takes the steps identified so
-- far and a point and a thread to backtrack to, and inserts at least -- far and a list of points and thread at that point to backtrack
-- that backtracking point. More may be added to compensate for the -- to. More points be added to compensate for the effects of the
-- effects of the bounding function. For example, under pre-emption -- bounding function. For example, under pre-emption bounding a
-- bounding a conservative backtracking point is added at the prior -- conservative backtracking point is added at the prior context
-- context switch. -- switch. The bool is whether the point is conservative. Conservative
-- points are always explored, whereas non-conservative ones might be
-- skipped based on future information.
-- --
-- In general, a backtracking function should identify one or more -- In general, a backtracking function should identify one or more
-- backtracking points, and then use @backtrackAt@ to do the actual -- backtracking points, and then use @backtrackAt@ to do the actual
-- work. -- work.
type BacktrackFunc type BacktrackFunc
= [BacktrackStep] -> Int -> ThreadId -> [BacktrackStep] = [BacktrackStep] -> [(Int, Bool, ThreadId)] -> [BacktrackStep]
-- | Add a backtracking point. If the thread isn't runnable, add all -- | Add a backtracking point. If the thread isn't runnable, add all
-- runnable threads. If the backtracking point is already present, -- runnable threads. If the backtracking point is already present,
-- don't re-add it UNLESS this would make it conservative. -- don't re-add it UNLESS this would make it conservative.
backtrackAt backtrackAt
:: (BacktrackStep -> Bool) :: (ThreadId -> BacktrackStep -> Bool)
-- ^ If this returns @True@, backtrack to all runnable threads, -- ^ If this returns @True@, backtrack to all runnable threads,
-- rather than just the given thread. -- rather than just the given thread.
-> Bool
-- ^ Is this backtracking point conservative? Conservative points
-- are always explored, whereas non-conservative ones might be
-- skipped based on future information.
-> BacktrackFunc -> BacktrackFunc
backtrackAt toAll conservative bs i tid = go bs i where backtrackAt toAll bs0 = backtrackAt' . nubBy ((==) `on` fst') . sortOn fst' where
go bx@(b:rest) 0 fst' (x,_,_) = x
backtrackAt' ((i,c,t):is) = go i bs0 i c t is
backtrackAt' [] = bs0
go i0 (b:bs) 0 c tid is
-- If the backtracking point is already present, don't re-add it, -- If the backtracking point is already present, don't re-add it,
-- UNLESS this would force it to backtrack (it's conservative) -- UNLESS this would force it to backtrack (it's conservative)
-- where before it might not. -- where before it might not.
| not (toAll b) && tid `M.member` bcktRunnable b = | not (toAll tid b) && tid `M.member` bcktRunnable b =
let val = M.lookup tid $ bcktBacktracks b let val = M.lookup tid $ bcktBacktracks b
in if isNothing val || (val == Just False && conservative) b' = if isNothing val || (val == Just False && c)
then b { bcktBacktracks = backtrackTo b } : rest then b { bcktBacktracks = backtrackTo tid c b }
else bx else b
in b' : case is of
((i',c',t'):is') -> go i' bs (i'-i0-1) c' t' is'
[] -> bs
-- Otherwise just backtrack to everything runnable. -- Otherwise just backtrack to everything runnable.
| otherwise = b { bcktBacktracks = backtrackAll b } : rest | otherwise =
let b' = b { bcktBacktracks = backtrackAll c b }
go (b:rest) n = b : go rest (n-1) in b' : case is of
go [] _ = error "backtrackAt: Ran out of schedule whilst backtracking!" ((i',c',t'):is') -> go i' bs (i'-i0-1) c' t' is'
[] -> bs
go i0 (b:bs) i c tid is = b : go i0 bs (i-1) c tid is
go _ [] _ _ _ _ = err "backtrackAt" "ran out of schedule whilst backtracking!"
-- Backtrack to a single thread -- Backtrack to a single thread
backtrackTo = M.insert tid conservative . bcktBacktracks backtrackTo tid c = M.insert tid c . bcktBacktracks
-- Backtrack to all runnable threads -- Backtrack to all runnable threads
backtrackAll = M.map (const conservative) . bcktRunnable backtrackAll c = M.map (const c) . bcktRunnable
-- | DPOR scheduler: takes a list of decisions, and maintains a trace -- | DPOR scheduler: takes a list of decisions, and maintains a trace
-- including the runnable threads, and the alternative choices allowed -- including the runnable threads, and the alternative choices allowed
@ -433,17 +413,14 @@ backtrackAt toAll conservative bs i tid = go bs i where
-- the prior thread if it's (1) still runnable and (2) hasn't just -- the prior thread if it's (1) still runnable and (2) hasn't just
-- yielded. Furthermore, threads which /will/ yield are ignored in -- yielded. Furthermore, threads which /will/ yield are ignored in
-- preference of those which will not. -- preference of those which will not.
--
-- This forces full evaluation of the result every step, to avoid any
-- possible space leaks.
dporSched dporSched
:: (DepState -> (ThreadId, ThreadAction) -> (ThreadId, ThreadAction) -> Bool) :: (DepState -> ThreadId -> ThreadAction -> ThreadId -> ThreadAction -> Bool)
-- ^ Dependency function. -- ^ Dependency function.
-> BoundFunc -> BoundFunc
-- ^ Bound function: returns true if that schedule prefix terminated -- ^ Bound function: returns true if that schedule prefix terminated
-- with the lookahead decision fits within the bound. -- with the lookahead decision fits within the bound.
-> Scheduler SchedState -> Scheduler SchedState
dporSched dependency inBound trc prior threads s = force schedule where dporSched dependency inBound trc prior threads s = schedule where
-- Pick a thread to run. -- Pick a thread to run.
schedule = case schedPrefix s of schedule = case schedPrefix s of
-- If there is a decision available, make it -- If there is a decision available, make it
@ -455,7 +432,7 @@ dporSched dependency inBound trc prior threads s = force schedule where
[] -> [] ->
let choices = restrictToBound initialise let choices = restrictToBound initialise
checkDep t a = case prior of checkDep t a = case prior of
Just (tid, act) -> dependency (schedDepState s) (tid, act) (t, a) Just (tid, act) -> dependency (schedDepState s) tid act t a
Nothing -> False Nothing -> False
ssleep' = M.filterWithKey (\t a -> not $ checkDep t a) $ schedSleep s ssleep' = M.filterWithKey (\t a -> not $ checkDep t a) $ schedSleep s
choices' = filter (`notElem` M.keys ssleep') choices choices' = filter (`notElem` M.keys ssleep') choices
@ -470,7 +447,7 @@ dporSched dependency inBound trc prior threads s = force schedule where
{ schedBPoints = schedBPoints s |> (threads, rest) { schedBPoints = schedBPoints s |> (threads, rest)
, schedDepState = nextDepState , schedDepState = nextDepState
} }
nextDepState = let ds = schedDepState s in maybe ds (updateDepState ds) prior nextDepState = let ds = schedDepState s in maybe ds (uncurry $ updateDepState ds) prior
-- Pick a new thread to run, not considering bounds. Choose the -- Pick a new thread to run, not considering bounds. Choose the
-- current thread if available and it hasn't just yielded, otherwise -- current thread if available and it hasn't just yielded, otherwise
@ -537,11 +514,7 @@ data DepState = DepState
-- the masking state is assumed to be @Unmasked@. This nicely -- the masking state is assumed to be @Unmasked@. This nicely
-- provides compatibility with dpor-0.1, where the thread IDs are -- provides compatibility with dpor-0.1, where the thread IDs are
-- not available. -- not available.
} } deriving (Eq, Show)
instance NFData DepState where
-- Cheats: 'MaskingState' has no 'NFData' instance.
rnf ds = rnf (depCRState ds, M.keys (depMaskState ds))
-- | Initial dependency state. -- | Initial dependency state.
initialDepState :: DepState initialDepState :: DepState
@ -549,8 +522,8 @@ initialDepState = DepState M.empty M.empty
-- | Update the 'CRef' buffer state with the action that has just -- | Update the 'CRef' buffer state with the action that has just
-- happened. -- happened.
updateDepState :: DepState -> (ThreadId, ThreadAction) -> DepState updateDepState :: DepState -> ThreadId -> ThreadAction -> DepState
updateDepState depstate (tid, act) = DepState updateDepState depstate tid act = DepState
{ depCRState = updateCRState act $ depCRState depstate { depCRState = updateCRState act $ depCRState depstate
, depMaskState = updateMaskState tid act $ depMaskState depstate , depMaskState = updateMaskState tid act $ depMaskState depstate
} }
@ -698,3 +671,17 @@ toDotFiltered check showTid showAct = digraph . go "L" where
-- | Internal errors. -- | Internal errors.
err :: String -> String -> a err :: String -> String -> a
err func msg = error (func ++ ": (internal error) " ++ msg) err func msg = error (func ++ ": (internal error) " ++ msg)
-- | A combination of 'partition' and 'concat'.
concatPartition :: (a -> Bool) -> [[a]] -> ([a], [a])
{-# INLINE concatPartition #-}
-- note: `foldr (flip (foldr select))` is slow, as is `foldl (foldl
-- select))`, and `foldl'` variants. The sweet spot seems to be `foldl
-- (foldr select)` for some reason I don't really understand.
concatPartition p = foldl (foldr select) ([], []) where
-- Lazy pattern matching, got this trick from the 'partition'
-- implementation. This reduces allocation fairly significantly; I
-- do not know why.
select a ~(ts, fs)
| p a = (a:ts, fs)
| otherwise = (ts, a:fs)

View File

@ -96,7 +96,6 @@ library
build-depends: base >=4.8 && <5 build-depends: base >=4.8 && <5
, concurrency >=1.0 && <1.1 , concurrency >=1.0 && <1.1
, containers >=0.5 && <0.6 , containers >=0.5 && <0.6
, deepseq >=1.3 && <1.5
, exceptions >=0.7 && <0.9 , exceptions >=0.7 && <0.9
, monad-loops >=0.4 && <0.5 , monad-loops >=0.4 && <0.5
, mtl >=2.2 && <2.3 , mtl >=2.2 && <2.3