mirror of
https://github.com/barrucadu/dejafu.git
synced 2024-12-23 21:42:09 +03:00
Merge branch 'subconcurrency'
This commit is contained in:
commit
75d2b6ca73
@ -11,6 +11,7 @@ import Test.HUnit.DejaFu (testDejafu)
|
||||
|
||||
import Control.Concurrent.Classy
|
||||
import Control.Monad.STM.Class
|
||||
import Test.DejaFu.Conc (Conc, subconcurrency)
|
||||
|
||||
#if __GLASGOW_HASKELL__ < 710
|
||||
import Control.Applicative ((<$>), (<*>))
|
||||
@ -49,6 +50,13 @@ tests =
|
||||
, testGroup "Daemons" . hUnitTestToTests $ test
|
||||
[ testDejafu schedDaemon "schedule daemon" $ gives' [0,1]
|
||||
]
|
||||
|
||||
, testGroup "Subconcurrency" . hUnitTestToTests $ test
|
||||
[ testDejafu scDeadlock1 "deadlock1" $ gives' [Left Deadlock, Right ()]
|
||||
, testDejafu scDeadlock2 "deadlock2" $ gives' [(Left Deadlock, ()), (Right (), ())]
|
||||
, testDejafu scSuccess "success" $ gives' [Right ()]
|
||||
, testDejafu scIllegal "illegal" $ gives [Left IllegalSubconcurrency]
|
||||
]
|
||||
]
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
@ -207,3 +215,40 @@ schedDaemon = do
|
||||
x <- newCRef 0
|
||||
_ <- fork $ myThreadId >> writeCRef x 1
|
||||
readCRef x
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- Subconcurrency
|
||||
|
||||
-- | Subcomputation deadlocks sometimes.
|
||||
scDeadlock1 :: Monad n => Conc n r (Either Failure ())
|
||||
scDeadlock1 = do
|
||||
var <- newEmptyMVar
|
||||
subconcurrency $ do
|
||||
void . fork $ putMVar var ()
|
||||
putMVar var ()
|
||||
|
||||
-- | Subcomputation deadlocks sometimes, and action after it still
|
||||
-- happens.
|
||||
scDeadlock2 :: Monad n => Conc n r (Either Failure (), ())
|
||||
scDeadlock2 = do
|
||||
var <- newEmptyMVar
|
||||
res <- subconcurrency $ do
|
||||
void . fork $ putMVar var ()
|
||||
putMVar var ()
|
||||
(,) <$> pure res <*> readMVar var
|
||||
|
||||
-- | Subcomputation successfully completes.
|
||||
scSuccess :: Monad n => Conc n r (Either Failure ())
|
||||
scSuccess = do
|
||||
var <- newMVar ()
|
||||
subconcurrency $ do
|
||||
out <- newEmptyMVar
|
||||
void . fork $ takeMVar var >>= putMVar out
|
||||
takeMVar out
|
||||
|
||||
-- | Illegal usage
|
||||
scIllegal :: Monad n => Conc n r ()
|
||||
scIllegal = do
|
||||
var <- newEmptyMVar
|
||||
void . fork $ readMVar var
|
||||
void . subconcurrency $ pure ()
|
||||
|
@ -3,6 +3,7 @@
|
||||
module Cases.SingleThreaded where
|
||||
|
||||
import Control.Exception (ArithException(..), ArrayException(..))
|
||||
import Control.Monad (void)
|
||||
import Test.DejaFu (Failure(..), gives, gives')
|
||||
import Test.Framework (Test, testGroup)
|
||||
import Test.Framework.Providers.HUnit (hUnitTestToTests)
|
||||
@ -10,7 +11,7 @@ import Test.HUnit (test)
|
||||
import Test.HUnit.DejaFu (testDejafu)
|
||||
|
||||
import Control.Concurrent.Classy
|
||||
import Control.Monad.STM.Class
|
||||
import Test.DejaFu.Conc (Conc, subconcurrency)
|
||||
|
||||
import Utils
|
||||
|
||||
@ -58,6 +59,12 @@ tests =
|
||||
[ testDejafu capsGet "get" $ gives' [True]
|
||||
, testDejafu capsSet "set" $ gives' [True]
|
||||
]
|
||||
|
||||
, testGroup "Subconcurrency" . hUnitTestToTests $ test
|
||||
[ testDejafu scDeadlock1 "deadlock1" $ gives' [Left Deadlock]
|
||||
, testDejafu scDeadlock2 "deadlock2" $ gives' [(Left Deadlock, ())]
|
||||
, testDejafu scSuccess "success" $ gives' [Right ()]
|
||||
]
|
||||
]
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
@ -252,3 +259,22 @@ capsSet = do
|
||||
caps <- getNumCapabilities
|
||||
setNumCapabilities $ caps + 1
|
||||
(== caps + 1) <$> getNumCapabilities
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- Subconcurrency
|
||||
|
||||
-- | Subcomputation deadlocks.
|
||||
scDeadlock1 :: Monad n => Conc n r (Either Failure ())
|
||||
scDeadlock1 = subconcurrency (newEmptyMVar >>= readMVar)
|
||||
|
||||
-- | Subcomputation deadlocks, and action after it still happens.
|
||||
scDeadlock2 :: Monad n => Conc n r (Either Failure (), ())
|
||||
scDeadlock2 = do
|
||||
var <- newMVar ()
|
||||
(,) <$> subconcurrency (putMVar var ()) <*> readMVar var
|
||||
|
||||
-- | Subcomputation successfully completes.
|
||||
scSuccess :: Monad n => Conc n r (Either Failure ())
|
||||
scSuccess = do
|
||||
var <- newMVar ()
|
||||
subconcurrency (takeMVar var)
|
||||
|
@ -240,7 +240,6 @@ module Test.DejaFu
|
||||
) where
|
||||
|
||||
import Control.Arrow (first)
|
||||
import Control.DeepSeq (NFData(..))
|
||||
import Control.Monad (when, unless)
|
||||
import Control.Monad.Ref (MonadRef)
|
||||
import Control.Monad.ST (runST)
|
||||
@ -400,9 +399,6 @@ defaultFail failures = Result False 0 failures ""
|
||||
defaultPass :: Result a
|
||||
defaultPass = Result True 0 [] ""
|
||||
|
||||
instance NFData a => NFData (Result a) where
|
||||
rnf r = rnf (_pass r, _casesChecked r, _failures r, _failureMsg r)
|
||||
|
||||
instance Functor Result where
|
||||
fmap f r = r { _failures = map (first $ fmap f) $ _failures r }
|
||||
|
||||
|
@ -60,7 +60,6 @@ module Test.DejaFu.Common
|
||||
, MemType(..)
|
||||
) where
|
||||
|
||||
import Control.DeepSeq (NFData(..))
|
||||
import Control.Exception (MaskingState(..))
|
||||
import Data.Dynamic (Dynamic)
|
||||
import Data.List (sort, nub, intercalate)
|
||||
@ -83,9 +82,6 @@ instance Show ThreadId where
|
||||
show (ThreadId (Just n) _) = n
|
||||
show (ThreadId Nothing i) = show i
|
||||
|
||||
instance NFData ThreadId where
|
||||
rnf (ThreadId n i) = rnf (n, i)
|
||||
|
||||
-- | Every @CRef@ has a unique identifier.
|
||||
data CRefId = CRefId (Maybe String) Int
|
||||
deriving Eq
|
||||
@ -97,9 +93,6 @@ instance Show CRefId where
|
||||
show (CRefId (Just n) _) = n
|
||||
show (CRefId Nothing i) = show i
|
||||
|
||||
instance NFData CRefId where
|
||||
rnf (CRefId n i) = rnf (n, i)
|
||||
|
||||
-- | Every @MVar@ has a unique identifier.
|
||||
data MVarId = MVarId (Maybe String) Int
|
||||
deriving Eq
|
||||
@ -111,9 +104,6 @@ instance Show MVarId where
|
||||
show (MVarId (Just n) _) = n
|
||||
show (MVarId Nothing i) = show i
|
||||
|
||||
instance NFData MVarId where
|
||||
rnf (MVarId n i) = rnf (n, i)
|
||||
|
||||
-- | Every @TVar@ has a unique identifier.
|
||||
data TVarId = TVarId (Maybe String) Int
|
||||
deriving Eq
|
||||
@ -125,9 +115,6 @@ instance Show TVarId where
|
||||
show (TVarId (Just n) _) = n
|
||||
show (TVarId Nothing i) = show i
|
||||
|
||||
instance NFData TVarId where
|
||||
rnf (TVarId n i) = rnf (n, i)
|
||||
|
||||
-- | The ID of the initial thread.
|
||||
initialThread :: ThreadId
|
||||
initialThread = ThreadId (Just "main") 0
|
||||
@ -272,38 +259,10 @@ data ThreadAction =
|
||||
-- ^ A '_concMessage' annotation was processed.
|
||||
| Stop
|
||||
-- ^ Cease execution and terminate.
|
||||
| Subconcurrency
|
||||
-- ^ Start executing an action with @subconcurrency@.
|
||||
deriving Show
|
||||
|
||||
instance NFData ThreadAction where
|
||||
rnf (Fork t) = rnf t
|
||||
rnf (GetNumCapabilities i) = rnf i
|
||||
rnf (SetNumCapabilities i) = rnf i
|
||||
rnf (NewVar c) = rnf c
|
||||
rnf (PutVar c ts) = rnf (c, ts)
|
||||
rnf (BlockedPutVar c) = rnf c
|
||||
rnf (TryPutVar c b ts) = rnf (c, b, ts)
|
||||
rnf (ReadVar c) = rnf c
|
||||
rnf (BlockedReadVar c) = rnf c
|
||||
rnf (TakeVar c ts) = rnf (c, ts)
|
||||
rnf (BlockedTakeVar c) = rnf c
|
||||
rnf (TryTakeVar c b ts) = rnf (c, b, ts)
|
||||
rnf (NewRef c) = rnf c
|
||||
rnf (ReadRef c) = rnf c
|
||||
rnf (ReadRefCas c) = rnf c
|
||||
rnf (ModRef c) = rnf c
|
||||
rnf (ModRefCas c) = rnf c
|
||||
rnf (WriteRef c) = rnf c
|
||||
rnf (CasRef c b) = rnf (c, b)
|
||||
rnf (CommitRef t c) = rnf (t, c)
|
||||
rnf (STM s ts) = rnf (s, ts)
|
||||
rnf (BlockedSTM s) = rnf s
|
||||
rnf (ThrowTo t) = rnf t
|
||||
rnf (BlockedThrowTo t) = rnf t
|
||||
rnf (SetMasking b m) = b `seq` m `seq` ()
|
||||
rnf (ResetMasking b m) = b `seq` m `seq` ()
|
||||
rnf (Message m) = m `seq` ()
|
||||
rnf a = a `seq` ()
|
||||
|
||||
-- | Check if a @ThreadAction@ immediately blocks.
|
||||
isBlock :: ThreadAction -> Bool
|
||||
isBlock (BlockedThrowTo _) = True
|
||||
@ -401,28 +360,10 @@ data Lookahead =
|
||||
-- ^ Will process a _concMessage' annotation.
|
||||
| WillStop
|
||||
-- ^ Will cease execution and terminate.
|
||||
| WillSubconcurrency
|
||||
-- ^ Will execute an action with @subconcurrency@.
|
||||
deriving Show
|
||||
|
||||
instance NFData Lookahead where
|
||||
rnf (WillSetNumCapabilities i) = rnf i
|
||||
rnf (WillPutVar c) = rnf c
|
||||
rnf (WillTryPutVar c) = rnf c
|
||||
rnf (WillReadVar c) = rnf c
|
||||
rnf (WillTakeVar c) = rnf c
|
||||
rnf (WillTryTakeVar c) = rnf c
|
||||
rnf (WillReadRef c) = rnf c
|
||||
rnf (WillReadRefCas c) = rnf c
|
||||
rnf (WillModRef c) = rnf c
|
||||
rnf (WillModRefCas c) = rnf c
|
||||
rnf (WillWriteRef c) = rnf c
|
||||
rnf (WillCasRef c) = rnf c
|
||||
rnf (WillCommitRef t c) = rnf (t, c)
|
||||
rnf (WillThrowTo t) = rnf t
|
||||
rnf (WillSetMasking b m) = b `seq` m `seq` ()
|
||||
rnf (WillResetMasking b m) = b `seq` m `seq` ()
|
||||
rnf (WillMessage m) = m `seq` ()
|
||||
rnf l = l `seq` ()
|
||||
|
||||
-- | Convert a 'ThreadAction' into a 'Lookahead': \"rewind\" what has
|
||||
-- happened. 'Killed' has no 'Lookahead' counterpart.
|
||||
rewind :: ThreadAction -> Maybe Lookahead
|
||||
@ -462,6 +403,7 @@ rewind LiftIO = Just WillLiftIO
|
||||
rewind Return = Just WillReturn
|
||||
rewind (Message m) = Just (WillMessage m)
|
||||
rewind Stop = Just WillStop
|
||||
rewind Subconcurrency = Just WillSubconcurrency
|
||||
|
||||
-- | Check if an operation could enable another thread.
|
||||
willRelease :: Lookahead -> Bool
|
||||
@ -508,17 +450,6 @@ data ActionType =
|
||||
-- communication.
|
||||
deriving (Eq, Show)
|
||||
|
||||
instance NFData ActionType where
|
||||
rnf (UnsynchronisedRead r) = rnf r
|
||||
rnf (UnsynchronisedWrite r) = rnf r
|
||||
rnf (PartiallySynchronisedCommit r) = rnf r
|
||||
rnf (PartiallySynchronisedWrite r) = rnf r
|
||||
rnf (PartiallySynchronisedModify r) = rnf r
|
||||
rnf (SynchronisedModify r) = rnf r
|
||||
rnf (SynchronisedRead c) = rnf c
|
||||
rnf (SynchronisedWrite c) = rnf c
|
||||
rnf a = a `seq` ()
|
||||
|
||||
-- | Check if an action imposes a write barrier.
|
||||
isBarrier :: ActionType -> Bool
|
||||
isBarrier (SynchronisedModify _) = True
|
||||
@ -611,13 +542,6 @@ data TAction =
|
||||
-- ^ Terminate successfully and commit effects.
|
||||
deriving (Eq, Show)
|
||||
|
||||
instance NFData TAction where
|
||||
rnf (TRead v) = rnf v
|
||||
rnf (TWrite v) = rnf v
|
||||
rnf (TCatch s m) = rnf (s, m)
|
||||
rnf (TOrElse s m) = rnf (s, m)
|
||||
rnf a = a `seq` ()
|
||||
|
||||
-------------------------------------------------------------------------------
|
||||
-- Traces
|
||||
|
||||
@ -640,11 +564,6 @@ data Decision =
|
||||
-- ^ Pre-empt the running thread, and switch to another.
|
||||
deriving (Eq, Show)
|
||||
|
||||
instance NFData Decision where
|
||||
rnf (Start tid) = rnf tid
|
||||
rnf (SwitchTo tid) = rnf tid
|
||||
rnf d = d `seq` ()
|
||||
|
||||
-- | Pretty-print a trace, including a key of the thread IDs (not
|
||||
-- including thread 0). Each line of the key is indented by two
|
||||
-- spaces.
|
||||
@ -675,15 +594,15 @@ showTrace trc = intercalate "\n" $ concatMap go trc : strkey where
|
||||
preEmpCount :: [(Decision, ThreadAction)]
|
||||
-> (Decision, Lookahead)
|
||||
-> Int
|
||||
preEmpCount ts (d, _) = go initialThread Nothing ts where
|
||||
go _ (Just Yield) ((SwitchTo t, a):rest) = go t (Just a) rest
|
||||
go tid prior ((SwitchTo t, a):rest)
|
||||
preEmpCount (x:xs) (d, _) = go initialThread x xs where
|
||||
go _ (_, Yield) (r@(SwitchTo t, _):rest) = go t r rest
|
||||
go tid prior (r@(SwitchTo t, _):rest)
|
||||
| isCommitThread t = go tid prior (skip rest)
|
||||
| otherwise = 1 + go t (Just a) rest
|
||||
go _ _ ((Start t, a):rest) = go t (Just a) rest
|
||||
go tid _ ((Continue, a):rest) = go tid (Just a) rest
|
||||
| otherwise = 1 + go t r rest
|
||||
go _ _ (r@(Start t, _):rest) = go t r rest
|
||||
go tid _ (r@(Continue, _):rest) = go tid r rest
|
||||
go _ prior [] = case (prior, d) of
|
||||
(Just Yield, SwitchTo _) -> 0
|
||||
((_, Yield), SwitchTo _) -> 0
|
||||
(_, SwitchTo _) -> 1
|
||||
_ -> 0
|
||||
|
||||
@ -694,6 +613,7 @@ preEmpCount ts (d, _) = go initialThread Nothing ts where
|
||||
skip = dropWhile (not . isContextSwitch . fst)
|
||||
isContextSwitch Continue = False
|
||||
isContextSwitch _ = True
|
||||
preEmpCount [] _ = 0
|
||||
|
||||
-------------------------------------------------------------------------------
|
||||
-- Failures
|
||||
@ -716,18 +636,19 @@ data Failure =
|
||||
-- ^ The computation became blocked indefinitely on @TVar@s.
|
||||
| UncaughtException
|
||||
-- ^ An uncaught exception bubbled to the top of the computation.
|
||||
| IllegalSubconcurrency
|
||||
-- ^ Calls to @subconcurrency@ were nested, or attempted when
|
||||
-- multiple threads existed.
|
||||
deriving (Eq, Show, Read, Ord, Enum, Bounded)
|
||||
|
||||
instance NFData Failure where
|
||||
rnf f = f `seq` ()
|
||||
|
||||
-- | Pretty-print a failure
|
||||
showFail :: Failure -> String
|
||||
showFail Abort = "[abort]"
|
||||
showFail Deadlock = "[deadlock]"
|
||||
showFail STMDeadlock = "[stm-deadlock]"
|
||||
showFail InternalError = "[internal-error]"
|
||||
showFail Abort = "[abort]"
|
||||
showFail Deadlock = "[deadlock]"
|
||||
showFail STMDeadlock = "[stm-deadlock]"
|
||||
showFail InternalError = "[internal-error]"
|
||||
showFail UncaughtException = "[exception]"
|
||||
showFail IllegalSubconcurrency = "[illegal-subconcurrency]"
|
||||
|
||||
-------------------------------------------------------------------------------
|
||||
-- Memory Models
|
||||
@ -751,9 +672,6 @@ data MemType =
|
||||
-- created.
|
||||
deriving (Eq, Show, Read, Ord, Enum, Bounded)
|
||||
|
||||
instance NFData MemType where
|
||||
rnf m = m `seq` ()
|
||||
|
||||
-------------------------------------------------------------------------------
|
||||
-- Utilities
|
||||
|
||||
|
@ -28,6 +28,7 @@ module Test.DejaFu.Conc
|
||||
, Failure(..)
|
||||
, MemType(..)
|
||||
, runConcurrent
|
||||
, subconcurrency
|
||||
|
||||
-- * Execution traces
|
||||
, Trace
|
||||
@ -49,12 +50,11 @@ import Control.Exception (MaskingState(..))
|
||||
import qualified Control.Monad.Base as Ba
|
||||
import qualified Control.Monad.Catch as Ca
|
||||
import qualified Control.Monad.IO.Class as IO
|
||||
import Control.Monad.Ref (MonadRef, newRef, readRef, writeRef)
|
||||
import Control.Monad.Ref (MonadRef,)
|
||||
import Control.Monad.ST (ST)
|
||||
import Data.Dynamic (toDyn)
|
||||
import qualified Data.Foldable as F
|
||||
import Data.IORef (IORef)
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Maybe (fromJust)
|
||||
import Data.STRef (STRef)
|
||||
import Test.DejaFu.Schedule
|
||||
|
||||
@ -62,13 +62,12 @@ import qualified Control.Monad.Conc.Class as C
|
||||
import Test.DejaFu.Common
|
||||
import Test.DejaFu.Conc.Internal
|
||||
import Test.DejaFu.Conc.Internal.Common
|
||||
import Test.DejaFu.Conc.Internal.Threading
|
||||
import Test.DejaFu.STM
|
||||
|
||||
{-# ANN module ("HLint: ignore Avoid lambda" :: String) #-}
|
||||
{-# ANN module ("HLint: ignore Use const" :: String) #-}
|
||||
|
||||
newtype Conc n r a = C { unC :: M n r (STMLike n r) a } deriving (Functor, Applicative, Monad)
|
||||
newtype Conc n r a = C { unC :: M n r a } deriving (Functor, Applicative, Monad)
|
||||
|
||||
-- | A 'MonadConc' implementation using @ST@, this should be preferred
|
||||
-- if you do not need 'liftIO'.
|
||||
@ -77,10 +76,10 @@ type ConcST t = Conc (ST t) (STRef t)
|
||||
-- | A 'MonadConc' implementation using @IO@.
|
||||
type ConcIO = Conc IO IORef
|
||||
|
||||
toConc :: ((a -> Action n r (STMLike n r)) -> Action n r (STMLike n r)) -> Conc n r a
|
||||
toConc :: ((a -> Action n r) -> Action n r) -> Conc n r a
|
||||
toConc = C . cont
|
||||
|
||||
wrap :: (M n r (STMLike n r) a -> M n r (STMLike n r) a) -> Conc n r a -> Conc n r a
|
||||
wrap :: (M n r a -> M n r a) -> Conc n r a -> Conc n r a
|
||||
wrap f = C . f . unC
|
||||
|
||||
instance IO.MonadIO ConcIO where
|
||||
@ -181,20 +180,15 @@ runConcurrent :: MonadRef r n
|
||||
-> s
|
||||
-> Conc n r a
|
||||
-> n (Either Failure a, s, Trace)
|
||||
runConcurrent sched memtype s (C conc) = do
|
||||
ref <- newRef Nothing
|
||||
runConcurrent sched memtype s ma = do
|
||||
(res, s', trace) <- runConcurrency sched memtype s (unC ma)
|
||||
pure (res, s', F.toList trace)
|
||||
|
||||
let c = runCont conc (AStop . writeRef ref . Just . Right)
|
||||
let threads = launch' Unmasked initialThread (const c) M.empty
|
||||
|
||||
(s', trace) <- runThreads runTransaction
|
||||
sched
|
||||
memtype
|
||||
s
|
||||
threads
|
||||
initialIdSource
|
||||
ref
|
||||
|
||||
out <- readRef ref
|
||||
|
||||
pure (fromJust out, s', reverse trace)
|
||||
-- | Run a concurrent computation and return its result.
|
||||
--
|
||||
-- This can only be called in the main thread, when no other threads
|
||||
-- exist. Calls to 'subconcurrency' cannot be nested. Violating either
|
||||
-- of these conditions will result in the computation failing with
|
||||
-- @IllegalSubconcurrency@.
|
||||
subconcurrency :: Conc n r a -> Conc n r (Either Failure a)
|
||||
subconcurrency ma = toConc (ASub (unC ma))
|
||||
|
@ -16,19 +16,23 @@
|
||||
module Test.DejaFu.Conc.Internal where
|
||||
|
||||
import Control.Exception (MaskingState(..), toException)
|
||||
import Control.Monad.Ref (MonadRef, newRef, writeRef)
|
||||
import Control.Monad.Ref (MonadRef, newRef, readRef, writeRef)
|
||||
import qualified Data.Foldable as F
|
||||
import Data.Functor (void)
|
||||
import Data.List (sort)
|
||||
import Data.List.NonEmpty (NonEmpty(..), fromList)
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Maybe (fromJust, isJust, isNothing, listToMaybe)
|
||||
import Data.Maybe (fromJust, isJust, isNothing)
|
||||
import Data.Monoid ((<>))
|
||||
import Data.Sequence (Seq, (<|))
|
||||
import qualified Data.Sequence as Seq
|
||||
|
||||
import Test.DejaFu.Common
|
||||
import Test.DejaFu.Conc.Internal.Common
|
||||
import Test.DejaFu.Conc.Internal.Memory
|
||||
import Test.DejaFu.Conc.Internal.Threading
|
||||
import Test.DejaFu.Schedule
|
||||
import Test.DejaFu.STM (Result(..))
|
||||
import Test.DejaFu.STM (Result(..), runTransaction)
|
||||
|
||||
{-# ANN module ("HLint: ignore Use record patterns" :: String) #-}
|
||||
{-# ANN module ("HLint: ignore Use const" :: String) #-}
|
||||
@ -36,40 +40,69 @@ import Test.DejaFu.STM (Result(..))
|
||||
--------------------------------------------------------------------------------
|
||||
-- * Execution
|
||||
|
||||
-- | 'Trace' but as a sequence.
|
||||
type SeqTrace
|
||||
= Seq (Decision, [(ThreadId, NonEmpty Lookahead)], ThreadAction)
|
||||
|
||||
-- | Run a concurrent computation with a given 'Scheduler' and initial
|
||||
-- state, returning a failure reason on error. Also returned is the
|
||||
-- final state of the scheduler, and an execution trace.
|
||||
runConcurrency :: MonadRef r n
|
||||
=> Scheduler g
|
||||
-> MemType
|
||||
-> g
|
||||
-> M n r a
|
||||
-> n (Either Failure a, g, SeqTrace)
|
||||
runConcurrency sched memtype g ma = do
|
||||
ref <- newRef Nothing
|
||||
|
||||
let c = runCont ma (AStop . writeRef ref . Just . Right)
|
||||
let threads = launch' Unmasked initialThread (const c) M.empty
|
||||
let ctx = Context { cSchedState = g, cIdSource = initialIdSource, cThreads = threads, cWriteBuf = emptyBuffer, cCaps = 2 }
|
||||
|
||||
(finalCtx, trace) <- runThreads sched memtype ref ctx
|
||||
out <- readRef ref
|
||||
pure (fromJust out, cSchedState finalCtx, trace)
|
||||
|
||||
-- | The context a collection of threads are running in.
|
||||
data Context n r g = Context
|
||||
{ cSchedState :: g
|
||||
, cIdSource :: IdSource
|
||||
, cThreads :: Threads n r
|
||||
, cWriteBuf :: WriteBuffer r
|
||||
, cCaps :: Int
|
||||
}
|
||||
|
||||
-- | Run a collection of threads, until there are no threads left.
|
||||
--
|
||||
-- Note: this returns the trace in reverse order, because it's more
|
||||
-- efficient to prepend to a list than append. As this function isn't
|
||||
-- exposed to users of the library, this is just an internal gotcha to
|
||||
-- watch out for.
|
||||
runThreads :: MonadRef r n => (forall x. s x -> IdSource -> n (Result x, IdSource, TTrace))
|
||||
-> Scheduler g -> MemType -> g -> Threads n r s -> IdSource -> r (Maybe (Either Failure a)) -> n (g, Trace)
|
||||
runThreads runstm sched memtype origg origthreads idsrc ref = go idsrc [] Nothing origg origthreads emptyBuffer 2 where
|
||||
go idSource sofar prior g threads wb caps
|
||||
| isTerminated = stop g
|
||||
| isDeadlocked = die g Deadlock
|
||||
| isSTMLocked = die g STMDeadlock
|
||||
| isAborted = die g' Abort
|
||||
| isNonexistant = die g' InternalError
|
||||
| isBlocked = die g' InternalError
|
||||
runThreads :: MonadRef r n
|
||||
=> Scheduler g -> MemType -> r (Maybe (Either Failure a)) -> Context n r g -> n (Context n r g, SeqTrace)
|
||||
runThreads sched memtype ref = go Seq.empty [] Nothing where
|
||||
-- sofar is the 'SeqTrace', sofarSched is the @[(Decision,
|
||||
-- ThreadAction)]@ trace the scheduler needs.
|
||||
go sofar sofarSched prior ctx
|
||||
| isTerminated = stop ctx
|
||||
| isDeadlocked = die Deadlock ctx
|
||||
| isSTMLocked = die STMDeadlock ctx
|
||||
| isAborted = die Abort $ ctx { cSchedState = g' }
|
||||
| isNonexistant = die InternalError $ ctx { cSchedState = g' }
|
||||
| isBlocked = die InternalError $ ctx { cSchedState = g' }
|
||||
| otherwise = do
|
||||
stepped <- stepThread runstm memtype (_continuation $ fromJust thread) idSource chosen threads wb caps
|
||||
stepped <- stepThread sched memtype chosen (_continuation $ fromJust thread) $ ctx { cSchedState = g' }
|
||||
case stepped of
|
||||
Right (threads', idSource', act, wb', caps') -> loop threads' idSource' act wb' caps'
|
||||
|
||||
Right (ctx', actOrTrc) -> loop actOrTrc ctx'
|
||||
Left UncaughtException
|
||||
| chosen == initialThread -> die g' UncaughtException
|
||||
| otherwise -> loop (kill chosen threads) idSource Killed wb caps
|
||||
|
||||
Left failure -> die g' failure
|
||||
| chosen == initialThread -> die UncaughtException $ ctx { cSchedState = g' }
|
||||
| otherwise -> loop (Right Killed) $ ctx { cThreads = kill chosen threadsc, cSchedState = g' }
|
||||
Left failure -> die failure $ ctx { cSchedState = g' }
|
||||
|
||||
where
|
||||
(choice, g') = sched (map (\(d,_,a) -> (d,a)) $ reverse sofar) ((\p (_,_,a) -> (p,a)) <$> prior <*> listToMaybe sofar) (fromList $ map (\(t,l:|_) -> (t,l)) runnable') g
|
||||
(choice, g') = sched sofarSched prior (fromList $ map (\(t,l:|_) -> (t,l)) runnable') (cSchedState ctx)
|
||||
chosen = fromJust choice
|
||||
runnable' = [(t, nextActions t) | t <- sort $ M.keys runnable]
|
||||
runnable = M.filter (isNothing . _blocking) threadsc
|
||||
thread = M.lookup chosen threadsc
|
||||
threadsc = addCommitThreads wb threads
|
||||
threadsc = addCommitThreads (cWriteBuf ctx) threads
|
||||
threads = cThreads ctx
|
||||
isAborted = isNothing choice
|
||||
isBlocked = isJust . _blocking $ fromJust thread
|
||||
isNonexistant = isNothing thread
|
||||
@ -87,299 +120,262 @@ runThreads runstm sched memtype origg origthreads idsrc ref = go idsrc [] Nothin
|
||||
_ -> thrd
|
||||
|
||||
decision
|
||||
| Just chosen == prior = Continue
|
||||
| prior `notElem` map (Just . fst) runnable' = Start chosen
|
||||
| Just chosen == (fst <$> prior) = Continue
|
||||
| (fst <$> prior) `notElem` map (Just . fst) runnable' = Start chosen
|
||||
| otherwise = SwitchTo chosen
|
||||
|
||||
nextActions t = lookahead . _continuation . fromJust $ M.lookup t threadsc
|
||||
|
||||
stop outg = pure (outg, sofar)
|
||||
die outg reason = writeRef ref (Just $ Left reason) >> stop outg
|
||||
stop finalCtx = pure (finalCtx, sofar)
|
||||
die reason finalCtx = writeRef ref (Just $ Left reason) >> stop finalCtx
|
||||
|
||||
loop threads' idSource' act wb' =
|
||||
let sofar' = ((decision, runnable', act) : sofar)
|
||||
threads'' = if (interruptible <$> M.lookup chosen threads') /= Just False then unblockWaitingOn chosen threads' else threads'
|
||||
in go idSource' sofar' (Just chosen) g' (delCommitThreads threads'') wb'
|
||||
loop trcOrAct ctx' =
|
||||
let (act, trc) = case trcOrAct of
|
||||
Left (a, as) -> (a, (decision, runnable', a) <| as)
|
||||
Right a -> (a, Seq.singleton (decision, runnable', a))
|
||||
threads' = if (interruptible <$> M.lookup chosen (cThreads ctx')) /= Just False
|
||||
then unblockWaitingOn chosen (cThreads ctx')
|
||||
else cThreads ctx'
|
||||
sofar' = sofar <> trc
|
||||
sofarSched' = sofarSched <> map (\(d,_,a) -> (d,a)) (F.toList trc)
|
||||
prior' = Just (chosen, act)
|
||||
in go sofar' sofarSched' prior' $ ctx' { cThreads = delCommitThreads threads' }
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- * Single-step execution
|
||||
|
||||
-- | Run a single thread one step, by dispatching on the type of
|
||||
-- 'Action'.
|
||||
stepThread :: forall n r s. MonadRef r n
|
||||
=> (forall x. s x -> IdSource -> n (Result x, IdSource, TTrace))
|
||||
-- ^ Run a 'MonadSTM' transaction atomically.
|
||||
stepThread :: forall n r g. MonadRef r n
|
||||
=> Scheduler g
|
||||
-- ^ The scheduler.
|
||||
-> MemType
|
||||
-- ^ The memory model
|
||||
-> Action n r s
|
||||
-- ^ Action to step
|
||||
-> IdSource
|
||||
-- ^ Source of fresh IDs
|
||||
-- ^ The memory model to use.
|
||||
-> ThreadId
|
||||
-- ^ ID of the current thread
|
||||
-> Threads n r s
|
||||
-- ^ Current state of threads
|
||||
-> WriteBuffer r
|
||||
-- ^ @CRef@ write buffer
|
||||
-> Int
|
||||
-- ^ The number of capabilities
|
||||
-> n (Either Failure (Threads n r s, IdSource, ThreadAction, WriteBuffer r, Int))
|
||||
stepThread runstm memtype action idSource tid threads wb caps = case action of
|
||||
AFork n a b -> stepFork n a b
|
||||
AMyTId c -> stepMyTId c
|
||||
AGetNumCapabilities c -> stepGetNumCapabilities c
|
||||
ASetNumCapabilities i c -> stepSetNumCapabilities i c
|
||||
AYield c -> stepYield c
|
||||
ANewVar n c -> stepNewVar n c
|
||||
APutVar var a c -> stepPutVar var a c
|
||||
ATryPutVar var a c -> stepTryPutVar var a c
|
||||
AReadVar var c -> stepReadVar var c
|
||||
ATakeVar var c -> stepTakeVar var c
|
||||
ATryTakeVar var c -> stepTryTakeVar var c
|
||||
ANewRef n a c -> stepNewRef n a c
|
||||
AReadRef ref c -> stepReadRef ref c
|
||||
AReadRefCas ref c -> stepReadRefCas ref c
|
||||
AModRef ref f c -> stepModRef ref f c
|
||||
AModRefCas ref f c -> stepModRefCas ref f c
|
||||
AWriteRef ref a c -> stepWriteRef ref a c
|
||||
ACasRef ref tick a c -> stepCasRef ref tick a c
|
||||
ACommit t c -> stepCommit t c
|
||||
AAtom stm c -> stepAtom stm c
|
||||
ALift na -> stepLift na
|
||||
AThrow e -> stepThrow e
|
||||
AThrowTo t e c -> stepThrowTo t e c
|
||||
ACatching h ma c -> stepCatching h ma c
|
||||
APopCatching a -> stepPopCatching a
|
||||
AMasking m ma c -> stepMasking m ma c
|
||||
AResetMask b1 b2 m c -> stepResetMask b1 b2 m c
|
||||
AReturn c -> stepReturn c
|
||||
AMessage m c -> stepMessage m c
|
||||
AStop na -> stepStop na
|
||||
-> Action n r
|
||||
-- ^ Action to step
|
||||
-> Context n r g
|
||||
-- ^ The execution context.
|
||||
-> n (Either Failure (Context n r g, Either (ThreadAction, SeqTrace) ThreadAction))
|
||||
stepThread sched memtype tid action ctx = case action of
|
||||
-- start a new thread, assigning it the next 'ThreadId'
|
||||
AFork n a b -> pure . Right $
|
||||
let threads' = launch tid newtid a (cThreads ctx)
|
||||
(idSource', newtid) = nextTId n (cIdSource ctx)
|
||||
in (ctx { cThreads = goto (b newtid) tid threads', cIdSource = idSource' }, Right (Fork newtid))
|
||||
|
||||
where
|
||||
-- | Start a new thread, assigning it the next 'ThreadId'
|
||||
--
|
||||
-- Explicit type signature needed for GHC 8. Looks like the
|
||||
-- impredicative polymorphism checks got stronger.
|
||||
stepFork :: String
|
||||
-> ((forall b. M n r s b -> M n r s b) -> Action n r s)
|
||||
-> (ThreadId -> Action n r s)
|
||||
-> n (Either Failure (Threads n r s, IdSource, ThreadAction, WriteBuffer r, Int))
|
||||
stepFork n a b = return $ Right (goto (b newtid) tid threads', idSource', Fork newtid, wb, caps) where
|
||||
threads' = launch tid newtid a threads
|
||||
(idSource', newtid) = nextTId n idSource
|
||||
-- get the 'ThreadId' of the current thread
|
||||
AMyTId c -> simple (goto (c tid) tid (cThreads ctx)) MyThreadId
|
||||
|
||||
-- | Get the 'ThreadId' of the current thread
|
||||
stepMyTId c = simple (goto (c tid) tid threads) MyThreadId
|
||||
-- get the number of capabilities
|
||||
AGetNumCapabilities c -> simple (goto (c (cCaps ctx)) tid (cThreads ctx)) $ GetNumCapabilities (cCaps ctx)
|
||||
|
||||
-- | Get the number of capabilities
|
||||
stepGetNumCapabilities c = simple (goto (c caps) tid threads) $ GetNumCapabilities caps
|
||||
-- set the number of capabilities
|
||||
ASetNumCapabilities i c -> pure . Right $
|
||||
(ctx { cThreads = goto c tid (cThreads ctx), cCaps = i }, Right (SetNumCapabilities i))
|
||||
|
||||
-- | Set the number of capabilities
|
||||
stepSetNumCapabilities i c = return $ Right (goto c tid threads, idSource, SetNumCapabilities i, wb, i)
|
||||
-- yield the current thread
|
||||
AYield c -> simple (goto c tid (cThreads ctx)) Yield
|
||||
|
||||
-- | Yield the current thread
|
||||
stepYield c = simple (goto c tid threads) Yield
|
||||
-- create a new @MVar@, using the next 'MVarId'.
|
||||
ANewVar n c -> do
|
||||
let (idSource', newmvid) = nextMVId n (cIdSource ctx)
|
||||
ref <- newRef Nothing
|
||||
let mvar = MVar newmvid ref
|
||||
pure $ Right (ctx { cThreads = goto (c mvar) tid (cThreads ctx), cIdSource = idSource' }, Right (NewVar newmvid))
|
||||
|
||||
-- | Put a value into a @MVar@, blocking the thread until it's
|
||||
-- empty.
|
||||
stepPutVar cvar@(MVar cvid _) a c = synchronised $ do
|
||||
(success, threads', woken) <- putIntoMVar cvar a c tid threads
|
||||
-- put a value into a @MVar@, blocking the thread until it's empty.
|
||||
APutVar cvar@(MVar cvid _) a c -> synchronised $ do
|
||||
(success, threads', woken) <- putIntoMVar cvar a c tid (cThreads ctx)
|
||||
simple threads' $ if success then PutVar cvid woken else BlockedPutVar cvid
|
||||
|
||||
-- | Try to put a value into a @MVar@, without blocking.
|
||||
stepTryPutVar cvar@(MVar cvid _) a c = synchronised $ do
|
||||
(success, threads', woken) <- tryPutIntoMVar cvar a c tid threads
|
||||
-- try to put a value into a @MVar@, without blocking.
|
||||
ATryPutVar cvar@(MVar cvid _) a c -> synchronised $ do
|
||||
(success, threads', woken) <- tryPutIntoMVar cvar a c tid (cThreads ctx)
|
||||
simple threads' $ TryPutVar cvid success woken
|
||||
|
||||
-- | Get the value from a @MVar@, without emptying, blocking the
|
||||
-- get the value from a @MVar@, without emptying, blocking the
|
||||
-- thread until it's full.
|
||||
stepReadVar cvar@(MVar cvid _) c = synchronised $ do
|
||||
(success, threads', _) <- readFromMVar cvar c tid threads
|
||||
AReadVar cvar@(MVar cvid _) c -> synchronised $ do
|
||||
(success, threads', _) <- readFromMVar cvar c tid (cThreads ctx)
|
||||
simple threads' $ if success then ReadVar cvid else BlockedReadVar cvid
|
||||
|
||||
-- | Take the value from a @MVar@, blocking the thread until it's
|
||||
-- take the value from a @MVar@, blocking the thread until it's
|
||||
-- full.
|
||||
stepTakeVar cvar@(MVar cvid _) c = synchronised $ do
|
||||
(success, threads', woken) <- takeFromMVar cvar c tid threads
|
||||
ATakeVar cvar@(MVar cvid _) c -> synchronised $ do
|
||||
(success, threads', woken) <- takeFromMVar cvar c tid (cThreads ctx)
|
||||
simple threads' $ if success then TakeVar cvid woken else BlockedTakeVar cvid
|
||||
|
||||
-- | Try to take the value from a @MVar@, without blocking.
|
||||
stepTryTakeVar cvar@(MVar cvid _) c = synchronised $ do
|
||||
(success, threads', woken) <- tryTakeFromMVar cvar c tid threads
|
||||
-- try to take the value from a @MVar@, without blocking.
|
||||
ATryTakeVar cvar@(MVar cvid _) c -> synchronised $ do
|
||||
(success, threads', woken) <- tryTakeFromMVar cvar c tid (cThreads ctx)
|
||||
simple threads' $ TryTakeVar cvid success woken
|
||||
|
||||
-- | Read from a @CRef@.
|
||||
stepReadRef cref@(CRef crid _) c = do
|
||||
-- create a new @CRef@, using the next 'CRefId'.
|
||||
ANewRef n a c -> do
|
||||
let (idSource', newcrid) = nextCRId n (cIdSource ctx)
|
||||
ref <- newRef (M.empty, 0, a)
|
||||
let cref = CRef newcrid ref
|
||||
pure $ Right (ctx { cThreads = goto (c cref) tid (cThreads ctx), cIdSource = idSource' }, Right (NewRef newcrid))
|
||||
|
||||
-- read from a @CRef@.
|
||||
AReadRef cref@(CRef crid _) c -> do
|
||||
val <- readCRef cref tid
|
||||
simple (goto (c val) tid threads) $ ReadRef crid
|
||||
simple (goto (c val) tid (cThreads ctx)) $ ReadRef crid
|
||||
|
||||
-- | Read from a @CRef@ for future compare-and-swap operations.
|
||||
stepReadRefCas cref@(CRef crid _) c = do
|
||||
-- read from a @CRef@ for future compare-and-swap operations.
|
||||
AReadRefCas cref@(CRef crid _) c -> do
|
||||
tick <- readForTicket cref tid
|
||||
simple (goto (c tick) tid threads) $ ReadRefCas crid
|
||||
simple (goto (c tick) tid (cThreads ctx)) $ ReadRefCas crid
|
||||
|
||||
-- | Modify a @CRef@.
|
||||
stepModRef cref@(CRef crid _) f c = synchronised $ do
|
||||
-- modify a @CRef@.
|
||||
AModRef cref@(CRef crid _) f c -> synchronised $ do
|
||||
(new, val) <- f <$> readCRef cref tid
|
||||
writeImmediate cref new
|
||||
simple (goto (c val) tid threads) $ ModRef crid
|
||||
simple (goto (c val) tid (cThreads ctx)) $ ModRef crid
|
||||
|
||||
-- | Modify a @CRef@ using a compare-and-swap.
|
||||
stepModRefCas cref@(CRef crid _) f c = synchronised $ do
|
||||
-- modify a @CRef@ using a compare-and-swap.
|
||||
AModRefCas cref@(CRef crid _) f c -> synchronised $ do
|
||||
tick@(Ticket _ _ old) <- readForTicket cref tid
|
||||
let (new, val) = f old
|
||||
void $ casCRef cref tid tick new
|
||||
simple (goto (c val) tid threads) $ ModRefCas crid
|
||||
simple (goto (c val) tid (cThreads ctx)) $ ModRefCas crid
|
||||
|
||||
-- | Write to a @CRef@ without synchronising
|
||||
stepWriteRef cref@(CRef crid _) a c = case memtype of
|
||||
-- Write immediately.
|
||||
-- write to a @CRef@ without synchronising.
|
||||
AWriteRef cref@(CRef crid _) a c -> case memtype of
|
||||
-- write immediately.
|
||||
SequentialConsistency -> do
|
||||
writeImmediate cref a
|
||||
simple (goto c tid threads) $ WriteRef crid
|
||||
|
||||
-- Add to buffer using thread id.
|
||||
simple (goto c tid (cThreads ctx)) $ WriteRef crid
|
||||
-- add to buffer using thread id.
|
||||
TotalStoreOrder -> do
|
||||
wb' <- bufferWrite wb (tid, Nothing) cref a
|
||||
return $ Right (goto c tid threads, idSource, WriteRef crid, wb', caps)
|
||||
|
||||
-- Add to buffer using both thread id and cref id
|
||||
wb' <- bufferWrite (cWriteBuf ctx) (tid, Nothing) cref a
|
||||
pure $ Right (ctx { cThreads = goto c tid (cThreads ctx), cWriteBuf = wb' }, Right (WriteRef crid))
|
||||
-- add to buffer using both thread id and cref id
|
||||
PartialStoreOrder -> do
|
||||
wb' <- bufferWrite wb (tid, Just crid) cref a
|
||||
return $ Right (goto c tid threads, idSource, WriteRef crid, wb', caps)
|
||||
wb' <- bufferWrite (cWriteBuf ctx) (tid, Just crid) cref a
|
||||
pure $ Right (ctx { cThreads = goto c tid (cThreads ctx), cWriteBuf = wb' }, Right (WriteRef crid))
|
||||
|
||||
-- | Perform a compare-and-swap on a @CRef@.
|
||||
stepCasRef cref@(CRef crid _) tick a c = synchronised $ do
|
||||
-- perform a compare-and-swap on a @CRef@.
|
||||
ACasRef cref@(CRef crid _) tick a c -> synchronised $ do
|
||||
(suc, tick') <- casCRef cref tid tick a
|
||||
simple (goto (c (suc, tick')) tid threads) $ CasRef crid suc
|
||||
simple (goto (c (suc, tick')) tid (cThreads ctx)) $ CasRef crid suc
|
||||
|
||||
-- | Commit a @CRef@ write
|
||||
stepCommit t c = do
|
||||
-- commit a @CRef@ write
|
||||
ACommit t c -> do
|
||||
wb' <- case memtype of
|
||||
-- Shouldn't ever get here
|
||||
-- shouldn't ever get here
|
||||
SequentialConsistency ->
|
||||
error "Attempting to commit under SequentialConsistency"
|
||||
-- commit using the thread id.
|
||||
TotalStoreOrder -> commitWrite (cWriteBuf ctx) (t, Nothing)
|
||||
-- commit using the cref id.
|
||||
PartialStoreOrder -> commitWrite (cWriteBuf ctx) (t, Just c)
|
||||
pure $ Right (ctx { cWriteBuf = wb' }, Right (CommitRef t c))
|
||||
|
||||
-- Commit using the thread id.
|
||||
TotalStoreOrder -> commitWrite wb (t, Nothing)
|
||||
|
||||
-- Commit using the cref id.
|
||||
PartialStoreOrder -> commitWrite wb (t, Just c)
|
||||
|
||||
return $ Right (threads, idSource, CommitRef t c, wb', caps)
|
||||
|
||||
-- | Run a STM transaction atomically.
|
||||
stepAtom stm c = synchronised $ do
|
||||
(res, idSource', trace) <- runstm stm idSource
|
||||
-- run a STM transaction atomically.
|
||||
AAtom stm c -> synchronised $ do
|
||||
(res, idSource', trace) <- runTransaction stm (cIdSource ctx)
|
||||
case res of
|
||||
Success _ written val ->
|
||||
let (threads', woken) = wake (OnTVar written) threads
|
||||
in return $ Right (goto (c val) tid threads', idSource', STM trace woken, wb, caps)
|
||||
let (threads', woken) = wake (OnTVar written) (cThreads ctx)
|
||||
in pure $ Right (ctx { cThreads = goto (c val) tid threads', cIdSource = idSource' }, Right (STM trace woken))
|
||||
Retry touched ->
|
||||
let threads' = block (OnTVar touched) tid threads
|
||||
in return $ Right (threads', idSource', BlockedSTM trace, wb, caps)
|
||||
let threads' = block (OnTVar touched) tid (cThreads ctx)
|
||||
in pure $ Right (ctx { cThreads = threads', cIdSource = idSource'}, Right (BlockedSTM trace))
|
||||
Exception e -> do
|
||||
res' <- stepThrow e
|
||||
return $ case res' of
|
||||
Right (threads', _, _, _, _) -> Right (threads', idSource', Throw, wb, caps)
|
||||
pure $ case res' of
|
||||
Right (ctx', _) -> Right (ctx' { cIdSource = idSource' }, Right Throw)
|
||||
Left err -> Left err
|
||||
|
||||
-- | Run a subcomputation in an exception-catching context.
|
||||
stepCatching h ma c = simple threads' Catching where
|
||||
a = runCont ma (APopCatching . c)
|
||||
e exc = runCont (h exc) (APopCatching . c)
|
||||
-- lift an action from the underlying monad into the @Conc@
|
||||
-- computation.
|
||||
ALift na -> do
|
||||
a <- na
|
||||
simple (goto a tid (cThreads ctx)) LiftIO
|
||||
|
||||
threads' = goto a tid (catching e tid threads)
|
||||
|
||||
-- | Pop the top exception handler from the thread's stack.
|
||||
stepPopCatching a = simple threads' PopCatching where
|
||||
threads' = goto a tid (uncatching tid threads)
|
||||
|
||||
-- | Throw an exception, and propagate it to the appropriate
|
||||
-- throw an exception, and propagate it to the appropriate
|
||||
-- handler.
|
||||
stepThrow e =
|
||||
case propagate (toException e) tid threads of
|
||||
Just threads' -> simple threads' Throw
|
||||
Nothing -> return $ Left UncaughtException
|
||||
AThrow e -> stepThrow e
|
||||
|
||||
-- | Throw an exception to the target thread, and propagate it to
|
||||
-- throw an exception to the target thread, and propagate it to
|
||||
-- the appropriate handler.
|
||||
stepThrowTo t e c = synchronised $
|
||||
let threads' = goto c tid threads
|
||||
blocked = block (OnMask t) tid threads
|
||||
in case M.lookup t threads of
|
||||
AThrowTo t e c -> synchronised $
|
||||
let threads' = goto c tid (cThreads ctx)
|
||||
blocked = block (OnMask t) tid (cThreads ctx)
|
||||
in case M.lookup t (cThreads ctx) of
|
||||
Just thread
|
||||
| interruptible thread -> case propagate (toException e) t threads' of
|
||||
Just threads'' -> simple threads'' $ ThrowTo t
|
||||
Nothing
|
||||
| t == initialThread -> return $ Left UncaughtException
|
||||
| t == initialThread -> pure $ Left UncaughtException
|
||||
| otherwise -> simple (kill t threads') $ ThrowTo t
|
||||
| otherwise -> simple blocked $ BlockedThrowTo t
|
||||
Nothing -> simple threads' $ ThrowTo t
|
||||
|
||||
-- | Execute a subcomputation with a new masking state, and give
|
||||
-- it a function to run a computation with the current masking
|
||||
-- state.
|
||||
--
|
||||
-- Explicit type sig necessary for checking in the prescence of
|
||||
-- 'umask', sadly.
|
||||
stepMasking :: MaskingState
|
||||
-> ((forall b. M n r s b -> M n r s b) -> M n r s a)
|
||||
-> (a -> Action n r s)
|
||||
-> n (Either Failure (Threads n r s, IdSource, ThreadAction, WriteBuffer r, Int))
|
||||
stepMasking m ma c = simple threads' $ SetMasking False m where
|
||||
a = runCont (ma umask) (AResetMask False False m' . c)
|
||||
-- run a subcomputation in an exception-catching context.
|
||||
ACatching h ma c ->
|
||||
let a = runCont ma (APopCatching . c)
|
||||
e exc = runCont (h exc) (APopCatching . c)
|
||||
threads' = goto a tid (catching e tid (cThreads ctx))
|
||||
in simple threads' Catching
|
||||
|
||||
m' = _masking . fromJust $ M.lookup tid threads
|
||||
umask mb = resetMask True m' >> mb >>= \b -> resetMask False m >> return b
|
||||
resetMask typ ms = cont $ \k -> AResetMask typ True ms $ k ()
|
||||
-- pop the top exception handler from the thread's stack.
|
||||
APopCatching a ->
|
||||
let threads' = goto a tid (uncatching tid (cThreads ctx))
|
||||
in simple threads' PopCatching
|
||||
|
||||
threads' = goto a tid (mask m tid threads)
|
||||
-- execute a subcomputation with a new masking state, and give it
|
||||
-- a function to run a computation with the current masking state.
|
||||
AMasking m ma c ->
|
||||
let a = runCont (ma umask) (AResetMask False False m' . c)
|
||||
m' = _masking . fromJust $ M.lookup tid (cThreads ctx)
|
||||
umask mb = resetMask True m' >> mb >>= \b -> resetMask False m >> pure b
|
||||
resetMask typ ms = cont $ \k -> AResetMask typ True ms $ k ()
|
||||
threads' = goto a tid (mask m tid (cThreads ctx))
|
||||
in simple threads' $ SetMasking False m
|
||||
|
||||
-- | Reset the masking thread of the state.
|
||||
stepResetMask b1 b2 m c = simple threads' act where
|
||||
act = (if b1 then SetMasking else ResetMasking) b2 m
|
||||
threads' = goto c tid (mask m tid threads)
|
||||
|
||||
-- | Create a new @MVar@, using the next 'MVarId'.
|
||||
stepNewVar n c = do
|
||||
let (idSource', newmvid) = nextMVId n idSource
|
||||
ref <- newRef Nothing
|
||||
let mvar = MVar newmvid ref
|
||||
return $ Right (goto (c mvar) tid threads, idSource', NewVar newmvid, wb, caps)
|
||||
-- reset the masking thread of the state.
|
||||
AResetMask b1 b2 m c ->
|
||||
let act = (if b1 then SetMasking else ResetMasking) b2 m
|
||||
threads' = goto c tid (mask m tid (cThreads ctx))
|
||||
in simple threads' act
|
||||
|
||||
-- | Create a new @CRef@, using the next 'CRefId'.
|
||||
stepNewRef n a c = do
|
||||
let (idSource', newcrid) = nextCRId n idSource
|
||||
ref <- newRef (M.empty, 0, a)
|
||||
let cref = CRef newcrid ref
|
||||
return $ Right (goto (c cref) tid threads, idSource', NewRef newcrid, wb, caps)
|
||||
-- execute a 'return' or 'pure'.
|
||||
AReturn c -> simple (goto c tid (cThreads ctx)) Return
|
||||
|
||||
-- | Lift an action from the underlying monad into the @Conc@
|
||||
-- computation.
|
||||
stepLift na = do
|
||||
a <- na
|
||||
simple (goto a tid threads) LiftIO
|
||||
-- add a message to the trace.
|
||||
AMessage m c -> simple (goto c tid (cThreads ctx)) (Message m)
|
||||
|
||||
-- | Execute a 'return' or 'pure'.
|
||||
stepReturn c = simple (goto c tid threads) Return
|
||||
-- kill the current thread.
|
||||
AStop na -> na >> simple (kill tid (cThreads ctx)) Stop
|
||||
|
||||
-- | Add a message to the trace.
|
||||
stepMessage m c = simple (goto c tid threads) (Message m)
|
||||
-- run a subconcurrent computation.
|
||||
ASub ma c
|
||||
| M.size (cThreads ctx) > 1 -> pure (Left IllegalSubconcurrency)
|
||||
| otherwise -> do
|
||||
(res, g', trace) <- runConcurrency sched memtype (cSchedState ctx) ma
|
||||
pure $ Right (ctx { cThreads = goto (c res) tid (cThreads ctx), cSchedState = g' }, Left (Subconcurrency, trace))
|
||||
where
|
||||
|
||||
-- | Kill the current thread.
|
||||
stepStop na = na >> simple (kill tid threads) Stop
|
||||
-- this is not inline in the long @case@ above as it's needed by
|
||||
-- @AAtom@, @AThrow@, and @AThrowTo@.
|
||||
stepThrow e =
|
||||
case propagate (toException e) tid (cThreads ctx) of
|
||||
Just threads' -> simple threads' Throw
|
||||
Nothing -> pure $ Left UncaughtException
|
||||
|
||||
-- | Helper for actions which don't touch the 'IdSource' or
|
||||
-- 'WriteBuffer'
|
||||
simple threads' act = return $ Right (threads', idSource, act, wb, caps)
|
||||
-- helper for actions which only change the threads.
|
||||
simple threads' act = pure $ Right (ctx { cThreads = threads' }, Right act)
|
||||
|
||||
-- | Helper for actions impose a write barrier.
|
||||
-- helper for actions impose a write barrier.
|
||||
synchronised ma = do
|
||||
writeBarrier wb
|
||||
writeBarrier (cWriteBuf ctx)
|
||||
res <- ma
|
||||
|
||||
return $ case res of
|
||||
Right (threads', idSource', act', _, caps') -> Right (threads', idSource', act', emptyBuffer, caps')
|
||||
Right (ctx', act) -> Right (ctx' { cWriteBuf = emptyBuffer }, act)
|
||||
_ -> res
|
||||
|
@ -18,6 +18,7 @@ import Data.Dynamic (Dynamic)
|
||||
import Data.Map.Strict (Map)
|
||||
import Data.List.NonEmpty (NonEmpty, fromList)
|
||||
import Test.DejaFu.Common
|
||||
import Test.DejaFu.STM (STMLike)
|
||||
|
||||
{-# ANN module ("HLint: ignore Use record patterns" :: String) #-}
|
||||
|
||||
@ -32,16 +33,16 @@ import Test.DejaFu.Common
|
||||
-- current expression of threads and exception handlers very difficult
|
||||
-- (perhaps even not possible without significant reworking), so I
|
||||
-- abandoned the attempt.
|
||||
newtype M n r s a = M { runM :: (a -> Action n r s) -> Action n r s }
|
||||
newtype M n r a = M { runM :: (a -> Action n r) -> Action n r }
|
||||
|
||||
instance Functor (M n r s) where
|
||||
instance Functor (M n r) where
|
||||
fmap f m = M $ \ c -> runM m (c . f)
|
||||
|
||||
instance Applicative (M n r s) where
|
||||
instance Applicative (M n r) where
|
||||
pure x = M $ \c -> AReturn $ c x
|
||||
f <*> v = M $ \c -> runM f (\g -> runM v (c . g))
|
||||
|
||||
instance Monad (M n r s) where
|
||||
instance Monad (M n r) where
|
||||
return = pure
|
||||
m >>= k = M $ \c -> runM m (\x -> runM (k x) c)
|
||||
|
||||
@ -82,11 +83,11 @@ data Ticket a = Ticket
|
||||
}
|
||||
|
||||
-- | Construct a continuation-passing operation from a function.
|
||||
cont :: ((a -> Action n r s) -> Action n r s) -> M n r s a
|
||||
cont :: ((a -> Action n r) -> Action n r) -> M n r a
|
||||
cont = M
|
||||
|
||||
-- | Run a CPS computation with the given final computation.
|
||||
runCont :: M n r s a -> (a -> Action n r s) -> Action n r s
|
||||
runCont :: M n r a -> (a -> Action n r) -> Action n r
|
||||
runCont = runM
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
@ -96,49 +97,51 @@ runCont = runM
|
||||
-- only occur as a result of an action, and they cover (most of) the
|
||||
-- primitives of the concurrency. 'spawn' is absent as it is
|
||||
-- implemented in terms of 'newEmptyMVar', 'fork', and 'putMVar'.
|
||||
data Action n r s =
|
||||
AFork String ((forall b. M n r s b -> M n r s b) -> Action n r s) (ThreadId -> Action n r s)
|
||||
| AMyTId (ThreadId -> Action n r s)
|
||||
data Action n r =
|
||||
AFork String ((forall b. M n r b -> M n r b) -> Action n r) (ThreadId -> Action n r)
|
||||
| AMyTId (ThreadId -> Action n r)
|
||||
|
||||
| AGetNumCapabilities (Int -> Action n r s)
|
||||
| ASetNumCapabilities Int (Action n r s)
|
||||
| AGetNumCapabilities (Int -> Action n r)
|
||||
| ASetNumCapabilities Int (Action n r)
|
||||
|
||||
| forall a. ANewVar String (MVar r a -> Action n r s)
|
||||
| forall a. APutVar (MVar r a) a (Action n r s)
|
||||
| forall a. ATryPutVar (MVar r a) a (Bool -> Action n r s)
|
||||
| forall a. AReadVar (MVar r a) (a -> Action n r s)
|
||||
| forall a. ATakeVar (MVar r a) (a -> Action n r s)
|
||||
| forall a. ATryTakeVar (MVar r a) (Maybe a -> Action n r s)
|
||||
| forall a. ANewVar String (MVar r a -> Action n r)
|
||||
| forall a. APutVar (MVar r a) a (Action n r)
|
||||
| forall a. ATryPutVar (MVar r a) a (Bool -> Action n r)
|
||||
| forall a. AReadVar (MVar r a) (a -> Action n r)
|
||||
| forall a. ATakeVar (MVar r a) (a -> Action n r)
|
||||
| forall a. ATryTakeVar (MVar r a) (Maybe a -> Action n r)
|
||||
|
||||
| forall a. ANewRef String a (CRef r a -> Action n r s)
|
||||
| forall a. AReadRef (CRef r a) (a -> Action n r s)
|
||||
| forall a. AReadRefCas (CRef r a) (Ticket a -> Action n r s)
|
||||
| forall a b. AModRef (CRef r a) (a -> (a, b)) (b -> Action n r s)
|
||||
| forall a b. AModRefCas (CRef r a) (a -> (a, b)) (b -> Action n r s)
|
||||
| forall a. AWriteRef (CRef r a) a (Action n r s)
|
||||
| forall a. ACasRef (CRef r a) (Ticket a) a ((Bool, Ticket a) -> Action n r s)
|
||||
| forall a. ANewRef String a (CRef r a -> Action n r)
|
||||
| forall a. AReadRef (CRef r a) (a -> Action n r)
|
||||
| forall a. AReadRefCas (CRef r a) (Ticket a -> Action n r)
|
||||
| forall a b. AModRef (CRef r a) (a -> (a, b)) (b -> Action n r)
|
||||
| forall a b. AModRefCas (CRef r a) (a -> (a, b)) (b -> Action n r)
|
||||
| forall a. AWriteRef (CRef r a) a (Action n r)
|
||||
| forall a. ACasRef (CRef r a) (Ticket a) a ((Bool, Ticket a) -> Action n r)
|
||||
|
||||
| forall e. Exception e => AThrow e
|
||||
| forall e. Exception e => AThrowTo ThreadId e (Action n r s)
|
||||
| forall a e. Exception e => ACatching (e -> M n r s a) (M n r s a) (a -> Action n r s)
|
||||
| APopCatching (Action n r s)
|
||||
| forall a. AMasking MaskingState ((forall b. M n r s b -> M n r s b) -> M n r s a) (a -> Action n r s)
|
||||
| AResetMask Bool Bool MaskingState (Action n r s)
|
||||
| forall e. Exception e => AThrowTo ThreadId e (Action n r)
|
||||
| forall a e. Exception e => ACatching (e -> M n r a) (M n r a) (a -> Action n r)
|
||||
| APopCatching (Action n r)
|
||||
| forall a. AMasking MaskingState ((forall b. M n r b -> M n r b) -> M n r a) (a -> Action n r)
|
||||
| AResetMask Bool Bool MaskingState (Action n r)
|
||||
|
||||
| AMessage Dynamic (Action n r s)
|
||||
| AMessage Dynamic (Action n r)
|
||||
|
||||
| forall a. AAtom (s a) (a -> Action n r s)
|
||||
| ALift (n (Action n r s))
|
||||
| AYield (Action n r s)
|
||||
| AReturn (Action n r s)
|
||||
| forall a. AAtom (STMLike n r a) (a -> Action n r)
|
||||
| ALift (n (Action n r))
|
||||
| AYield (Action n r)
|
||||
| AReturn (Action n r)
|
||||
| ACommit ThreadId CRefId
|
||||
| AStop (n ())
|
||||
|
||||
| forall a. ASub (M n r a) (Either Failure a -> Action n r)
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- * Scheduling & Traces
|
||||
|
||||
-- | Look as far ahead in the given continuation as possible.
|
||||
lookahead :: Action n r s -> NonEmpty Lookahead
|
||||
lookahead :: Action n r -> NonEmpty Lookahead
|
||||
lookahead = fromList . lookahead' where
|
||||
lookahead' (AFork _ _ _) = [WillFork]
|
||||
lookahead' (AMyTId _) = [WillMyThreadId]
|
||||
@ -170,3 +173,4 @@ lookahead = fromList . lookahead' where
|
||||
lookahead' (AYield k) = WillYield : lookahead' k
|
||||
lookahead' (AReturn k) = WillReturn : lookahead' k
|
||||
lookahead' (AStop _) = [WillStop]
|
||||
lookahead' (ASub _ _) = [WillSubconcurrency]
|
||||
|
@ -126,7 +126,7 @@ writeBarrier (WriteBuffer wb) = mapM_ flush $ M.elems wb where
|
||||
flush = mapM_ $ \(BufferedWrite _ cref a) -> writeImmediate cref a
|
||||
|
||||
-- | Add phantom threads to the thread list to commit pending writes.
|
||||
addCommitThreads :: WriteBuffer r -> Threads n r s -> Threads n r s
|
||||
addCommitThreads :: WriteBuffer r -> Threads n r -> Threads n r
|
||||
addCommitThreads (WriteBuffer wb) ts = ts <> M.fromList phantoms where
|
||||
phantoms = [ (ThreadId Nothing $ negate tid, mkthread $ fromJust c)
|
||||
| ((k, b), tid) <- zip (M.toList wb) [1..]
|
||||
@ -136,41 +136,41 @@ addCommitThreads (WriteBuffer wb) ts = ts <> M.fromList phantoms where
|
||||
go EmptyL = Nothing
|
||||
|
||||
-- | Remove phantom threads.
|
||||
delCommitThreads :: Threads n r s -> Threads n r s
|
||||
delCommitThreads :: Threads n r -> Threads n r
|
||||
delCommitThreads = M.filterWithKey $ \k _ -> k >= initialThread
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- * Manipulating @MVar@s
|
||||
|
||||
-- | Put into a @MVar@, blocking if full.
|
||||
putIntoMVar :: MonadRef r n => MVar r a -> a -> Action n r s
|
||||
-> ThreadId -> Threads n r s -> n (Bool, Threads n r s, [ThreadId])
|
||||
putIntoMVar :: MonadRef r n => MVar r a -> a -> Action n r
|
||||
-> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
|
||||
putIntoMVar cvar a c = mutMVar True cvar a (const c)
|
||||
|
||||
-- | Try to put into a @MVar@, not blocking if full.
|
||||
tryPutIntoMVar :: MonadRef r n => MVar r a -> a -> (Bool -> Action n r s)
|
||||
-> ThreadId -> Threads n r s -> n (Bool, Threads n r s, [ThreadId])
|
||||
tryPutIntoMVar :: MonadRef r n => MVar r a -> a -> (Bool -> Action n r)
|
||||
-> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
|
||||
tryPutIntoMVar = mutMVar False
|
||||
|
||||
-- | Read from a @MVar@, blocking if empty.
|
||||
readFromMVar :: MonadRef r n => MVar r a -> (a -> Action n r s)
|
||||
-> ThreadId -> Threads n r s -> n (Bool, Threads n r s, [ThreadId])
|
||||
readFromMVar :: MonadRef r n => MVar r a -> (a -> Action n r)
|
||||
-> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
|
||||
readFromMVar cvar c = seeMVar False True cvar (c . fromJust)
|
||||
|
||||
-- | Take from a @MVar@, blocking if empty.
|
||||
takeFromMVar :: MonadRef r n => MVar r a -> (a -> Action n r s)
|
||||
-> ThreadId -> Threads n r s -> n (Bool, Threads n r s, [ThreadId])
|
||||
takeFromMVar :: MonadRef r n => MVar r a -> (a -> Action n r)
|
||||
-> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
|
||||
takeFromMVar cvar c = seeMVar True True cvar (c . fromJust)
|
||||
|
||||
-- | Try to take from a @MVar@, not blocking if empty.
|
||||
tryTakeFromMVar :: MonadRef r n => MVar r a -> (Maybe a -> Action n r s)
|
||||
-> ThreadId -> Threads n r s -> n (Bool, Threads n r s, [ThreadId])
|
||||
tryTakeFromMVar :: MonadRef r n => MVar r a -> (Maybe a -> Action n r)
|
||||
-> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
|
||||
tryTakeFromMVar = seeMVar True False
|
||||
|
||||
-- | Mutate a @MVar@, in either a blocking or nonblocking way.
|
||||
mutMVar :: MonadRef r n
|
||||
=> Bool -> MVar r a -> a -> (Bool -> Action n r s)
|
||||
-> ThreadId -> Threads n r s -> n (Bool, Threads n r s, [ThreadId])
|
||||
=> Bool -> MVar r a -> a -> (Bool -> Action n r)
|
||||
-> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
|
||||
mutMVar blocking (MVar cvid ref) a c threadid threads = do
|
||||
val <- readRef ref
|
||||
|
||||
@ -191,8 +191,8 @@ mutMVar blocking (MVar cvid ref) a c threadid threads = do
|
||||
-- | Read a @MVar@, in either a blocking or nonblocking
|
||||
-- way.
|
||||
seeMVar :: MonadRef r n
|
||||
=> Bool -> Bool -> MVar r a -> (Maybe a -> Action n r s)
|
||||
-> ThreadId -> Threads n r s -> n (Bool, Threads n r s, [ThreadId])
|
||||
=> Bool -> Bool -> MVar r a -> (Maybe a -> Action n r)
|
||||
-> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
|
||||
seeMVar emptying blocking (MVar cvid ref) c threadid threads = do
|
||||
val <- readRef ref
|
||||
|
||||
|
@ -27,22 +27,22 @@ import qualified Data.Map.Strict as M
|
||||
-- * Threads
|
||||
|
||||
-- | Threads are stored in a map index by 'ThreadId'.
|
||||
type Threads n r s = Map ThreadId (Thread n r s)
|
||||
type Threads n r = Map ThreadId (Thread n r)
|
||||
|
||||
-- | All the state of a thread.
|
||||
data Thread n r s = Thread
|
||||
{ _continuation :: Action n r s
|
||||
data Thread n r = Thread
|
||||
{ _continuation :: Action n r
|
||||
-- ^ The next action to execute.
|
||||
, _blocking :: Maybe BlockedOn
|
||||
-- ^ The state of any blocks.
|
||||
, _handlers :: [Handler n r s]
|
||||
, _handlers :: [Handler n r]
|
||||
-- ^ Stack of exception handlers
|
||||
, _masking :: MaskingState
|
||||
-- ^ The exception masking state.
|
||||
}
|
||||
|
||||
-- | Construct a thread with just one action
|
||||
mkthread :: Action n r s -> Thread n r s
|
||||
mkthread :: Action n r -> Thread n r
|
||||
mkthread c = Thread c Nothing [] Unmasked
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
@ -53,7 +53,7 @@ mkthread c = Thread c Nothing [] Unmasked
|
||||
data BlockedOn = OnMVarFull MVarId | OnMVarEmpty MVarId | OnTVar [TVarId] | OnMask ThreadId deriving Eq
|
||||
|
||||
-- | Determine if a thread is blocked in a certain way.
|
||||
(~=) :: Thread n r s -> BlockedOn -> Bool
|
||||
(~=) :: Thread n r -> BlockedOn -> Bool
|
||||
thread ~= theblock = case (_blocking thread, theblock) of
|
||||
(Just (OnMVarFull _), OnMVarFull _) -> True
|
||||
(Just (OnMVarEmpty _), OnMVarEmpty _) -> True
|
||||
@ -65,11 +65,11 @@ thread ~= theblock = case (_blocking thread, theblock) of
|
||||
-- * Exceptions
|
||||
|
||||
-- | An exception handler.
|
||||
data Handler n r s = forall e. Exception e => Handler (e -> Action n r s)
|
||||
data Handler n r = forall e. Exception e => Handler (e -> Action n r)
|
||||
|
||||
-- | Propagate an exception upwards, finding the closest handler
|
||||
-- which can deal with it.
|
||||
propagate :: SomeException -> ThreadId -> Threads n r s -> Maybe (Threads n r s)
|
||||
propagate :: SomeException -> ThreadId -> Threads n r -> Maybe (Threads n r)
|
||||
propagate e tid threads = case M.lookup tid threads >>= go . _handlers of
|
||||
Just (act, hs) -> Just $ except act hs tid threads
|
||||
Nothing -> Nothing
|
||||
@ -79,40 +79,40 @@ propagate e tid threads = case M.lookup tid threads >>= go . _handlers of
|
||||
go (Handler h:hs) = maybe (go hs) (\act -> Just (act, hs)) $ h <$> fromException e
|
||||
|
||||
-- | Check if a thread can be interrupted by an exception.
|
||||
interruptible :: Thread n r s -> Bool
|
||||
interruptible :: Thread n r -> Bool
|
||||
interruptible thread = _masking thread == Unmasked || (_masking thread == MaskedInterruptible && isJust (_blocking thread))
|
||||
|
||||
-- | Register a new exception handler.
|
||||
catching :: Exception e => (e -> Action n r s) -> ThreadId -> Threads n r s -> Threads n r s
|
||||
catching :: Exception e => (e -> Action n r) -> ThreadId -> Threads n r -> Threads n r
|
||||
catching h = M.alter $ \(Just thread) -> Just $ thread { _handlers = Handler h : _handlers thread }
|
||||
|
||||
-- | Remove the most recent exception handler.
|
||||
uncatching :: ThreadId -> Threads n r s -> Threads n r s
|
||||
uncatching :: ThreadId -> Threads n r -> Threads n r
|
||||
uncatching = M.alter $ \(Just thread) -> Just $ thread { _handlers = tail $ _handlers thread }
|
||||
|
||||
-- | Raise an exception in a thread.
|
||||
except :: Action n r s -> [Handler n r s] -> ThreadId -> Threads n r s -> Threads n r s
|
||||
except :: Action n r -> [Handler n r] -> ThreadId -> Threads n r -> Threads n r
|
||||
except act hs = M.alter $ \(Just thread) -> Just $ thread { _continuation = act, _handlers = hs, _blocking = Nothing }
|
||||
|
||||
-- | Set the masking state of a thread.
|
||||
mask :: MaskingState -> ThreadId -> Threads n r s -> Threads n r s
|
||||
mask :: MaskingState -> ThreadId -> Threads n r -> Threads n r
|
||||
mask ms = M.alter $ \(Just thread) -> Just $ thread { _masking = ms }
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- * Manipulating threads
|
||||
|
||||
-- | Replace the @Action@ of a thread.
|
||||
goto :: Action n r s -> ThreadId -> Threads n r s -> Threads n r s
|
||||
goto :: Action n r -> ThreadId -> Threads n r -> Threads n r
|
||||
goto a = M.alter $ \(Just thread) -> Just (thread { _continuation = a })
|
||||
|
||||
-- | Start a thread with the given ID, inheriting the masking state
|
||||
-- from the parent thread. This ID must not already be in use!
|
||||
launch :: ThreadId -> ThreadId -> ((forall b. M n r s b -> M n r s b) -> Action n r s) -> Threads n r s -> Threads n r s
|
||||
launch :: ThreadId -> ThreadId -> ((forall b. M n r b -> M n r b) -> Action n r) -> Threads n r -> Threads n r
|
||||
launch parent tid a threads = launch' ms tid a threads where
|
||||
ms = fromMaybe Unmasked $ _masking <$> M.lookup parent threads
|
||||
|
||||
-- | Start a thread with the given ID and masking state. This must not already be in use!
|
||||
launch' :: MaskingState -> ThreadId -> ((forall b. M n r s b -> M n r s b) -> Action n r s) -> Threads n r s -> Threads n r s
|
||||
launch' :: MaskingState -> ThreadId -> ((forall b. M n r b -> M n r b) -> Action n r) -> Threads n r -> Threads n r
|
||||
launch' ms tid a = M.insert tid thread where
|
||||
thread = Thread { _continuation = a umask, _blocking = Nothing, _handlers = [], _masking = ms }
|
||||
|
||||
@ -120,11 +120,11 @@ launch' ms tid a = M.insert tid thread where
|
||||
resetMask typ m = cont $ \k -> AResetMask typ True m $ k ()
|
||||
|
||||
-- | Kill a thread.
|
||||
kill :: ThreadId -> Threads n r s -> Threads n r s
|
||||
kill :: ThreadId -> Threads n r -> Threads n r
|
||||
kill = M.delete
|
||||
|
||||
-- | Block a thread.
|
||||
block :: BlockedOn -> ThreadId -> Threads n r s -> Threads n r s
|
||||
block :: BlockedOn -> ThreadId -> Threads n r -> Threads n r
|
||||
block blockedOn = M.alter doBlock where
|
||||
doBlock (Just thread) = Just $ thread { _blocking = Just blockedOn }
|
||||
doBlock _ = error "Invariant failure in 'block': thread does NOT exist!"
|
||||
@ -132,7 +132,7 @@ block blockedOn = M.alter doBlock where
|
||||
-- | Unblock all threads waiting on the appropriate block. For 'TVar'
|
||||
-- blocks, this will wake all threads waiting on at least one of the
|
||||
-- given 'TVar's.
|
||||
wake :: BlockedOn -> Threads n r s -> (Threads n r s, [ThreadId])
|
||||
wake :: BlockedOn -> Threads n r -> (Threads n r, [ThreadId])
|
||||
wake blockedOn threads = (unblock <$> threads, M.keys $ M.filter isBlocked threads) where
|
||||
unblock thread
|
||||
| isBlocked thread = thread { _blocking = Nothing }
|
||||
|
@ -79,9 +79,8 @@ module Test.DejaFu.SCT
|
||||
, sctLengthBound
|
||||
) where
|
||||
|
||||
import Control.DeepSeq (NFData(..))
|
||||
import Control.Monad.Ref (MonadRef)
|
||||
import Data.List (nub)
|
||||
import Data.List (foldl')
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Maybe (isJust, fromJust)
|
||||
import qualified Data.Set as S
|
||||
@ -140,17 +139,16 @@ cBound (Bounds pb fb lb) =
|
||||
--
|
||||
-- If no bounds are enabled, just backtrack to the given point.
|
||||
cBacktrack :: Bounds -> BacktrackFunc
|
||||
cBacktrack (Bounds Nothing Nothing Nothing) bs i t = backtrackAt (const False) False bs i t
|
||||
cBacktrack (Bounds pb fb lb) bs i t = lBack . fBack $ pBack bs where
|
||||
pBack backs = if isJust pb then pBacktrack backs i t else backs
|
||||
fBack backs = if isJust fb then fBacktrack backs i t else backs
|
||||
lBack backs = if isJust lb then lBacktrack backs i t else backs
|
||||
cBacktrack (Bounds (Just _) _ _) = pBacktrack
|
||||
cBacktrack (Bounds _ (Just _) _) = fBacktrack
|
||||
cBacktrack (Bounds _ _ (Just _)) = lBacktrack
|
||||
cBacktrack _ = backtrackAt (\_ _ -> False)
|
||||
|
||||
-------------------------------------------------------------------------------
|
||||
-- Pre-emption bounding
|
||||
|
||||
newtype PreemptionBound = PreemptionBound Int
|
||||
deriving (NFData, Enum, Eq, Ord, Num, Real, Integral, Read, Show)
|
||||
deriving (Enum, Eq, Ord, Num, Real, Integral, Read, Show)
|
||||
|
||||
-- | A sensible default preemption bound: 2.
|
||||
--
|
||||
@ -181,31 +179,26 @@ pBound (PreemptionBound pb) ts dl = preEmpCount ts dl <= pb
|
||||
-- the same state being reached multiple times, but is needed because
|
||||
-- of the artificial dependency imposed by the bound.
|
||||
pBacktrack :: BacktrackFunc
|
||||
pBacktrack bs i tid =
|
||||
maybe id (\j' b -> backtrack True b j' tid) j $ backtrack False bs i tid
|
||||
pBacktrack bs = backtrackAt (\_ _ -> False) bs . concatMap addConservative where
|
||||
addConservative o@(i, _, tid) = o : case conservative i of
|
||||
Just j -> [(j, True, tid)]
|
||||
Nothing -> []
|
||||
|
||||
where
|
||||
-- Index of the conservative point
|
||||
j = goJ . reverse . pairs $ zip [0..i-1] bs where
|
||||
goJ (((_,b1), (j',b2)):rest)
|
||||
| bcktThreadid b1 /= bcktThreadid b2
|
||||
&& not (isCommitRef . snd $ bcktDecision b1)
|
||||
&& not (isCommitRef . snd $ bcktDecision b2) = Just j'
|
||||
| otherwise = goJ rest
|
||||
goJ [] = Nothing
|
||||
|
||||
-- List of adjacent pairs
|
||||
{-# INLINE pairs #-}
|
||||
pairs = zip <*> tail
|
||||
|
||||
-- Add a backtracking point.
|
||||
backtrack = backtrackAt $ const False
|
||||
-- index of conservative point
|
||||
conservative i = go (reverse (take (i-1) bs)) (i-1) where
|
||||
go _ (-1) = Nothing
|
||||
go (b1:rest@(b2:_)) j
|
||||
| bcktThreadid b1 /= bcktThreadid b2
|
||||
&& not (isCommitRef $ bcktAction b1)
|
||||
&& not (isCommitRef $ bcktAction b2) = Just j
|
||||
| otherwise = go rest (j-1)
|
||||
go _ _ = Nothing
|
||||
|
||||
-------------------------------------------------------------------------------
|
||||
-- Fair bounding
|
||||
|
||||
newtype FairBound = FairBound Int
|
||||
deriving (NFData, Enum, Eq, Ord, Num, Real, Integral, Read, Show)
|
||||
deriving (Enum, Eq, Ord, Num, Real, Integral, Read, Show)
|
||||
|
||||
-- | A sensible default fair bound: 5.
|
||||
--
|
||||
@ -233,15 +226,15 @@ fBound (FairBound fb) ts (_, l) = maxYieldCountDiff ts l <= fb
|
||||
-- | Add a backtrack point. If the thread isn't runnable, or performs
|
||||
-- a release operation, add all runnable threads.
|
||||
fBacktrack :: BacktrackFunc
|
||||
fBacktrack bs i t = backtrackAt check False bs i t where
|
||||
fBacktrack = backtrackAt check where
|
||||
-- True if a release operation is performed.
|
||||
check b = Just True == (willRelease <$> M.lookup t (bcktRunnable b))
|
||||
check t b = Just True == (willRelease <$> M.lookup t (bcktRunnable b))
|
||||
|
||||
-------------------------------------------------------------------------------
|
||||
-- Length bounding
|
||||
|
||||
newtype LengthBound = LengthBound Int
|
||||
deriving (NFData, Enum, Eq, Ord, Num, Real, Integral, Read, Show)
|
||||
deriving (Enum, Eq, Ord, Num, Real, Integral, Read, Show)
|
||||
|
||||
-- | A sensible default length bound: 250.
|
||||
--
|
||||
@ -269,7 +262,7 @@ lBound (LengthBound lb) ts _ = length ts < lb
|
||||
-- | Add a backtrack point. If the thread isn't runnable, add all
|
||||
-- runnable threads.
|
||||
lBacktrack :: BacktrackFunc
|
||||
lBacktrack = backtrackAt (const False) False
|
||||
lBacktrack = backtrackAt (\_ _ -> False)
|
||||
|
||||
-------------------------------------------------------------------------------
|
||||
-- DPOR
|
||||
@ -313,7 +306,7 @@ sctBounded memtype bf backtrack conc = go initialState where
|
||||
|
||||
if schedIgnore s
|
||||
then go newDPOR
|
||||
else ((res, trace):) <$> go (pruneCommits $ addBacktracks bpoints newDPOR)
|
||||
else ((res, trace):) <$> go (addBacktracks bpoints newDPOR)
|
||||
|
||||
Nothing -> pure []
|
||||
|
||||
@ -332,32 +325,11 @@ sctBounded memtype bf backtrack conc = go initialState where
|
||||
-- Incorporate the new backtracking steps into the DPOR tree.
|
||||
addBacktracks = incorporateBacktrackSteps bf
|
||||
|
||||
-------------------------------------------------------------------------------
|
||||
-- Post-processing
|
||||
|
||||
-- | Remove commits from the todo sets where every other action will
|
||||
-- result in a write barrier (and so a commit) occurring.
|
||||
--
|
||||
-- To get the benefit from this, do not execute commit actions from
|
||||
-- the todo set until there are no other choises.
|
||||
pruneCommits :: DPOR -> DPOR
|
||||
pruneCommits bpor
|
||||
| not onlycommits || not alldonesync = go bpor
|
||||
| otherwise = go bpor { dporTodo = M.empty }
|
||||
|
||||
where
|
||||
go b = b { dporDone = pruneCommits <$> dporDone bpor }
|
||||
|
||||
onlycommits = all (<initialThread) . M.keys $ dporTodo bpor
|
||||
alldonesync = all barrier . M.elems $ dporDone bpor
|
||||
|
||||
barrier = isBarrier . simplifyAction . fromJust . dporAction
|
||||
|
||||
-------------------------------------------------------------------------------
|
||||
-- Dependency function
|
||||
|
||||
-- | Check if an action is dependent on another.
|
||||
dependent :: MemType -> DepState -> (ThreadId, ThreadAction) -> (ThreadId, ThreadAction) -> Bool
|
||||
dependent :: MemType -> DepState -> ThreadId -> ThreadAction -> ThreadId -> ThreadAction -> Bool
|
||||
-- This is basically the same as 'dependent'', but can make use of the
|
||||
-- additional information in a 'ThreadAction' to make different
|
||||
-- decisions in a few cases:
|
||||
@ -381,14 +353,14 @@ dependent :: MemType -> DepState -> (ThreadId, ThreadAction) -> (ThreadId, Threa
|
||||
-- - Dependency of STM transactions can be /greatly/ improved here,
|
||||
-- as the 'Lookahead' does not know which @TVar@s will be touched,
|
||||
-- and so has to assume all transactions are dependent.
|
||||
dependent _ _ (_, SetNumCapabilities a) (_, GetNumCapabilities b) = a /= b
|
||||
dependent _ ds (_, ThrowTo t) (t2, a) = t == t2 && canInterrupt ds t2 a
|
||||
dependent memtype ds (t1, a1) (t2, a2) = case rewind a2 of
|
||||
dependent _ _ _ (SetNumCapabilities a) _ (GetNumCapabilities b) = a /= b
|
||||
dependent _ ds _ (ThrowTo t) t2 a = t == t2 && canInterrupt ds t2 a
|
||||
dependent memtype ds t1 a1 t2 a2 = case rewind a2 of
|
||||
Just l2
|
||||
| isSTM a1 && isSTM a2
|
||||
-> not . S.null $ tvarsOf a1 `S.intersection` tvarsOf a2
|
||||
| not (isBlock a1 && isBarrier (simplifyLookahead l2)) ->
|
||||
dependent' memtype ds (t1, a1) (t2, l2)
|
||||
dependent' memtype ds t1 a1 t2 l2
|
||||
_ -> dependentActions memtype ds (simplifyAction a1) (simplifyAction a2)
|
||||
|
||||
where
|
||||
@ -400,8 +372,8 @@ dependent memtype ds (t1, a1) (t2, a2) = case rewind a2 of
|
||||
--
|
||||
-- Termination of the initial thread is handled specially in the DPOR
|
||||
-- implementation.
|
||||
dependent' :: MemType -> DepState -> (ThreadId, ThreadAction) -> (ThreadId, Lookahead) -> Bool
|
||||
dependent' memtype ds (t1, a1) (t2, l2) = case (a1, l2) of
|
||||
dependent' :: MemType -> DepState -> ThreadId -> ThreadAction -> ThreadId -> Lookahead -> Bool
|
||||
dependent' memtype ds t1 a1 t2 l2 = case (a1, l2) of
|
||||
-- Worst-case assumption: all IO is dependent.
|
||||
(LiftIO, WillLiftIO) -> True
|
||||
|
||||
@ -496,6 +468,7 @@ yieldCount tid ts l = go initialThread ts where
|
||||
| t == tid && willYield l = 1
|
||||
| otherwise = 0
|
||||
|
||||
{-# INLINE go' #-}
|
||||
go' t t' act rest
|
||||
| t == tid && didYield act = 1 + go t' rest
|
||||
| otherwise = go t' rest
|
||||
@ -505,10 +478,14 @@ yieldCount tid ts l = go initialThread ts where
|
||||
maxYieldCountDiff :: [(Decision, ThreadAction)]
|
||||
-> Lookahead
|
||||
-> Int
|
||||
maxYieldCountDiff ts l = maximum yieldCountDiffs where
|
||||
yieldsBy tid = yieldCount tid ts l
|
||||
yieldCounts = [yieldsBy tid | tid <- nub $ allTids ts]
|
||||
yieldCountDiffs = [y1 - y2 | y1 <- yieldCounts, y2 <- yieldCounts]
|
||||
maxYieldCountDiff ts l = go 0 yieldCounts where
|
||||
go m (yc:ycs) =
|
||||
let m' = m `max` foldl' (go' yc) 0 ycs
|
||||
in go m' ycs
|
||||
go m [] = m
|
||||
go' yc0 m yc = m `max` abs (yc0 - yc)
|
||||
|
||||
yieldCounts = [yieldCount t ts l | t <- allTids ts]
|
||||
|
||||
-- All the threads created during the lifetime of the system.
|
||||
allTids ((_, act):rest) =
|
||||
|
@ -11,14 +11,14 @@
|
||||
-- interface of this library.
|
||||
module Test.DejaFu.SCT.Internal where
|
||||
|
||||
import Control.DeepSeq (NFData(..), force)
|
||||
import Control.Exception (MaskingState(..))
|
||||
import Data.Char (ord)
|
||||
import Data.List (foldl', intercalate, partition, sortBy)
|
||||
import Data.Function (on)
|
||||
import qualified Data.Foldable as F
|
||||
import Data.List (intercalate, nubBy, partition, sortOn)
|
||||
import Data.List.NonEmpty (NonEmpty(..), toList)
|
||||
import Data.Ord (Down(..), comparing)
|
||||
import Data.Map.Strict (Map)
|
||||
import Data.Maybe (fromJust, isJust, isNothing, mapMaybe)
|
||||
import Data.Maybe (catMaybes, fromJust, isNothing)
|
||||
import qualified Data.Map.Strict as M
|
||||
import Data.Set (Set)
|
||||
import qualified Data.Set as S
|
||||
@ -51,16 +51,7 @@ data DPOR = DPOR
|
||||
, dporAction :: Maybe ThreadAction
|
||||
-- ^ What happened at this step. This will be 'Nothing' at the root,
|
||||
-- 'Just' everywhere else.
|
||||
}
|
||||
|
||||
instance NFData DPOR where
|
||||
rnf dpor = rnf ( dporRunnable dpor
|
||||
, dporTodo dpor
|
||||
, dporDone dpor
|
||||
, dporSleep dpor
|
||||
, dporTaken dpor
|
||||
, dporAction dpor
|
||||
)
|
||||
} deriving Show
|
||||
|
||||
-- | One step of the execution, including information for backtracking
|
||||
-- purposes. This backtracking information is used to generate new
|
||||
@ -68,7 +59,9 @@ instance NFData DPOR where
|
||||
data BacktrackStep = BacktrackStep
|
||||
{ bcktThreadid :: ThreadId
|
||||
-- ^ The thread running at this step
|
||||
, bcktDecision :: (Decision, ThreadAction)
|
||||
, bcktDecision :: Decision
|
||||
-- ^ What was decided at this step.
|
||||
, bcktAction :: ThreadAction
|
||||
-- ^ What happened at this step.
|
||||
, bcktRunnable :: Map ThreadId Lookahead
|
||||
-- ^ The threads runnable at this step
|
||||
@ -77,15 +70,7 @@ data BacktrackStep = BacktrackStep
|
||||
-- alternatives were added conservatively due to the bound.
|
||||
, bcktState :: DepState
|
||||
-- ^ Some domain-specific state at this point.
|
||||
}
|
||||
|
||||
instance NFData BacktrackStep where
|
||||
rnf b = rnf ( bcktThreadid b
|
||||
, bcktDecision b
|
||||
, bcktRunnable b
|
||||
, bcktBacktracks b
|
||||
, bcktState b
|
||||
)
|
||||
} deriving Show
|
||||
|
||||
-- | Initial DPOR state, given an initial thread ID. This initial
|
||||
-- thread should exist and be runnable at the start of execution.
|
||||
@ -127,24 +112,17 @@ findSchedulePrefix predicate idx dpor0
|
||||
(ts, c, slp) = allPrefixes !! i
|
||||
in Just (ts, c, slp, g)
|
||||
where
|
||||
allPrefixes = go (initialDPORThread dpor0) dpor0
|
||||
allPrefixes = go dpor0
|
||||
|
||||
go tid dpor =
|
||||
go dpor =
|
||||
-- All the possible prefix traces from this point, with
|
||||
-- updated DPOR subtrees if taken from the done list.
|
||||
let prefixes = concatMap go' (M.toList $ dporDone dpor) ++ here dpor
|
||||
-- Sort by number of preemptions, in descending order.
|
||||
cmp = Down . preEmps tid dpor . (\(a,_,_) -> a)
|
||||
sorted = sortBy (comparing cmp) prefixes
|
||||
let prefixes = here dpor : map go' (M.toList $ dporDone dpor)
|
||||
in case concatPartition (\(t:_,_,_) -> predicate t) prefixes of
|
||||
([], choices) -> choices
|
||||
(choices, _) -> choices
|
||||
|
||||
in if null prefixes
|
||||
then []
|
||||
else case partition (\(t:_,_,_) -> predicate t) sorted of
|
||||
([], []) -> err "findSchedulePrefix" "empty prefix list!"
|
||||
([], choices) -> choices
|
||||
(choices, _) -> choices
|
||||
|
||||
go' (tid, dpor) = (\(ts,c,slp) -> (tid:ts,c,slp)) <$> go tid dpor
|
||||
go' (tid, dpor) = (\(ts,c,slp) -> (tid:ts,c,slp)) <$> go dpor
|
||||
|
||||
-- Prefix traces terminating with a to-do decision at this point.
|
||||
here dpor = [([t], c, sleeps dpor) | (t, c) <- M.toList $ dporTodo dpor]
|
||||
@ -154,16 +132,10 @@ findSchedulePrefix predicate idx dpor0
|
||||
-- explored.
|
||||
sleeps dpor = dporSleep dpor `M.union` dporTaken dpor
|
||||
|
||||
-- The number of pre-emptive context switches
|
||||
preEmps tid dpor (t:ts) =
|
||||
let rest = preEmps t (fromJust . M.lookup t $ dporDone dpor) ts
|
||||
in if tid `S.member` dporRunnable dpor then 1 + rest else rest
|
||||
preEmps _ _ [] = 0::Int
|
||||
|
||||
-- | Add a new trace to the tree, creating a new subtree branching off
|
||||
-- at the point where the \"to-do\" decision was made.
|
||||
incorporateTrace
|
||||
:: (DepState -> (ThreadId, ThreadAction) -> (ThreadId, ThreadAction) -> Bool)
|
||||
:: (DepState -> ThreadId -> ThreadAction -> ThreadId -> ThreadAction -> Bool)
|
||||
-- ^ Dependency function
|
||||
-> Bool
|
||||
-- ^ Whether the \"to-do\" point which was used to create this new
|
||||
@ -176,7 +148,7 @@ incorporateTrace
|
||||
incorporateTrace dependency conservative trace dpor0 = grow initialDepState (initialDPORThread dpor0) trace dpor0 where
|
||||
grow state tid trc@((d, _, a):rest) dpor =
|
||||
let tid' = tidOf tid d
|
||||
state' = updateDepState state (tid', a)
|
||||
state' = updateDepState state tid' a
|
||||
in case M.lookup tid' (dporDone dpor) of
|
||||
Just dpor' ->
|
||||
let done = M.insert tid' (grow state' tid' rest dpor') (dporDone dpor)
|
||||
@ -193,8 +165,8 @@ incorporateTrace dependency conservative trace dpor0 = grow initialDepState (ini
|
||||
|
||||
-- Construct a new subtree corresponding to a trace suffix.
|
||||
subtree state tid sleep ((_, _, a):rest) =
|
||||
let state' = updateDepState state (tid, a)
|
||||
sleep' = M.filterWithKey (\t a' -> not $ dependency state' (tid, a) (t,a')) sleep
|
||||
let state' = updateDepState state tid a
|
||||
sleep' = M.filterWithKey (\t a' -> not $ dependency state' tid a t a') sleep
|
||||
in DPOR
|
||||
{ dporRunnable = S.fromList $ case rest of
|
||||
((_, runnable, _):_) -> map fst runnable
|
||||
@ -225,7 +197,7 @@ incorporateTrace dependency conservative trace dpor0 = grow initialDepState (ini
|
||||
-- runnable, a dependency is imposed between this final action and
|
||||
-- everything else.
|
||||
findBacktrackSteps
|
||||
:: (DepState -> (ThreadId, ThreadAction) -> (ThreadId, Lookahead) -> Bool)
|
||||
:: (DepState -> ThreadId -> ThreadAction -> ThreadId -> Lookahead -> Bool)
|
||||
-- ^ Dependency function.
|
||||
-> BacktrackFunc
|
||||
-- ^ Backtracking function. Given a list of backtracking points, and
|
||||
@ -244,17 +216,16 @@ findBacktrackSteps
|
||||
-> Trace
|
||||
-- ^ The execution trace.
|
||||
-> [BacktrackStep]
|
||||
findBacktrackSteps _ _ _ bcktrck
|
||||
| Sq.null bcktrck = const []
|
||||
findBacktrackSteps dependency backtrack boundKill bcktrck = go initialDepState S.empty initialThread [] (Sq.viewl bcktrck) where
|
||||
findBacktrackSteps dependency backtrack boundKill = go initialDepState S.empty initialThread [] . F.toList where
|
||||
-- Walk through the traces one step at a time, building up a list of
|
||||
-- new backtracking points.
|
||||
go state allThreads tid bs ((e,i):<is) ((d,_,a):ts) =
|
||||
go state allThreads tid bs ((e,i):is) ((d,_,a):ts) =
|
||||
let tid' = tidOf tid d
|
||||
state' = updateDepState state (tid', a)
|
||||
state' = updateDepState state tid' a
|
||||
this = BacktrackStep
|
||||
{ bcktThreadid = tid'
|
||||
, bcktDecision = (d, a)
|
||||
, bcktDecision = d
|
||||
, bcktAction = a
|
||||
, bcktRunnable = M.fromList . toList $ e
|
||||
, bcktBacktracks = M.fromList $ map (\i' -> (i', False)) i
|
||||
, bcktState = state'
|
||||
@ -263,30 +234,41 @@ findBacktrackSteps dependency backtrack boundKill bcktrck = go initialDepState S
|
||||
runnable = S.fromList (M.keys $ bcktRunnable this)
|
||||
allThreads' = allThreads `S.union` runnable
|
||||
killsEarly = null ts && boundKill
|
||||
in go state' allThreads' tid' bs' (Sq.viewl is) ts
|
||||
in go state' allThreads' tid' bs' is ts
|
||||
go _ _ _ bs _ _ = bs
|
||||
|
||||
-- Find the prior actions dependent with this one and add
|
||||
-- backtracking points.
|
||||
doBacktrack killsEarly allThreads enabledThreads bs =
|
||||
let tagged = reverse $ zip [0..] bs
|
||||
idxs = [ (head is, u)
|
||||
idxs = [ (head is, False, u)
|
||||
| (u, n) <- enabledThreads
|
||||
, v <- S.toList allThreads
|
||||
, u /= v
|
||||
, let is = idxs' u n v tagged
|
||||
, not $ null is]
|
||||
|
||||
idxs' u n v = mapMaybe go' where
|
||||
go' (i, b)
|
||||
idxs' u n v = catMaybes . go' True where
|
||||
{-# INLINE go' #-}
|
||||
go' final ((i,b):rest)
|
||||
-- Don't cross subconcurrency boundaries
|
||||
| isSubC final b = []
|
||||
-- If this is the final action in the trace and the
|
||||
-- execution was killed due to nothing being within bounds
|
||||
-- (@killsEarly == True@) assume worst-case dependency.
|
||||
| bcktThreadid b == v && (killsEarly || isDependent b) = Just i
|
||||
| otherwise = Nothing
|
||||
| bcktThreadid b == v && (killsEarly || isDependent b) = Just i : go' False rest
|
||||
| otherwise = go' False rest
|
||||
go' _ [] = []
|
||||
|
||||
isDependent b = dependency (bcktState b) (bcktThreadid b, snd $ bcktDecision b) (u, n)
|
||||
in foldl' (\b (i, u) -> backtrack b i u) bs idxs
|
||||
{-# INLINE isSubC #-}
|
||||
isSubC final b = case bcktAction b of
|
||||
Stop -> not final && bcktThreadid b == initialThread
|
||||
Subconcurrency -> bcktThreadid b == initialThread
|
||||
_ -> False
|
||||
|
||||
{-# INLINE isDependent #-}
|
||||
isDependent b = dependency (bcktState b) (bcktThreadid b) (bcktAction b) u n
|
||||
in backtrack bs idxs
|
||||
|
||||
-- | Add new backtracking points, if they have not already been
|
||||
-- visited, fit into the bound, and aren't in the sleep set.
|
||||
@ -302,10 +284,9 @@ incorporateBacktrackSteps bv = go Nothing [] where
|
||||
go priorTid pref (b:bs) bpor =
|
||||
let bpor' = doBacktrack priorTid pref b bpor
|
||||
tid = bcktThreadid b
|
||||
pref' = pref ++ [bcktDecision b]
|
||||
pref' = pref ++ [(bcktDecision b, bcktAction b)]
|
||||
child = go (Just tid) pref' bs . fromJust $ M.lookup tid (dporDone bpor)
|
||||
in bpor' { dporDone = M.insert tid child $ dporDone bpor' }
|
||||
|
||||
go _ _ [] bpor = bpor
|
||||
|
||||
doBacktrack priorTid pref b bpor =
|
||||
@ -343,16 +324,7 @@ data SchedState = SchedState
|
||||
, schedDepState :: DepState
|
||||
-- ^ State used by the dependency function to determine when to
|
||||
-- remove decisions from the sleep set.
|
||||
}
|
||||
|
||||
instance NFData SchedState where
|
||||
rnf s = rnf ( schedSleep s
|
||||
, schedPrefix s
|
||||
, schedBPoints s
|
||||
, schedIgnore s
|
||||
, schedBoundKill s
|
||||
, schedDepState s
|
||||
)
|
||||
} deriving Show
|
||||
|
||||
-- | Initial scheduler state for a given prefix
|
||||
initialSchedState :: Map ThreadId ThreadAction
|
||||
@ -378,52 +350,60 @@ type BoundFunc
|
||||
-- | A backtracking step is a point in the execution where another
|
||||
-- decision needs to be made, in order to explore interesting new
|
||||
-- schedules. A backtracking /function/ takes the steps identified so
|
||||
-- far and a point and a thread to backtrack to, and inserts at least
|
||||
-- that backtracking point. More may be added to compensate for the
|
||||
-- effects of the bounding function. For example, under pre-emption
|
||||
-- bounding a conservative backtracking point is added at the prior
|
||||
-- context switch.
|
||||
-- far and a list of points and thread at that point to backtrack
|
||||
-- to. More points be added to compensate for the effects of the
|
||||
-- bounding function. For example, under pre-emption bounding a
|
||||
-- conservative backtracking point is added at the prior context
|
||||
-- switch. The bool is whether the point is conservative. Conservative
|
||||
-- points are always explored, whereas non-conservative ones might be
|
||||
-- skipped based on future information.
|
||||
--
|
||||
-- In general, a backtracking function should identify one or more
|
||||
-- backtracking points, and then use @backtrackAt@ to do the actual
|
||||
-- work.
|
||||
type BacktrackFunc
|
||||
= [BacktrackStep] -> Int -> ThreadId -> [BacktrackStep]
|
||||
= [BacktrackStep] -> [(Int, Bool, ThreadId)] -> [BacktrackStep]
|
||||
|
||||
-- | Add a backtracking point. If the thread isn't runnable, add all
|
||||
-- runnable threads. If the backtracking point is already present,
|
||||
-- don't re-add it UNLESS this would make it conservative.
|
||||
backtrackAt
|
||||
:: (BacktrackStep -> Bool)
|
||||
:: (ThreadId -> BacktrackStep -> Bool)
|
||||
-- ^ If this returns @True@, backtrack to all runnable threads,
|
||||
-- rather than just the given thread.
|
||||
-> Bool
|
||||
-- ^ Is this backtracking point conservative? Conservative points
|
||||
-- are always explored, whereas non-conservative ones might be
|
||||
-- skipped based on future information.
|
||||
-> BacktrackFunc
|
||||
backtrackAt toAll conservative bs i tid = go bs i where
|
||||
go bx@(b:rest) 0
|
||||
backtrackAt toAll bs0 = backtrackAt' . nubBy ((==) `on` fst') . sortOn fst' where
|
||||
fst' (x,_,_) = x
|
||||
|
||||
backtrackAt' ((i,c,t):is) = go i bs0 i c t is
|
||||
backtrackAt' [] = bs0
|
||||
|
||||
go i0 (b:bs) 0 c tid is
|
||||
-- If the backtracking point is already present, don't re-add it,
|
||||
-- UNLESS this would force it to backtrack (it's conservative)
|
||||
-- where before it might not.
|
||||
| not (toAll b) && tid `M.member` bcktRunnable b =
|
||||
| not (toAll tid b) && tid `M.member` bcktRunnable b =
|
||||
let val = M.lookup tid $ bcktBacktracks b
|
||||
in if isNothing val || (val == Just False && conservative)
|
||||
then b { bcktBacktracks = backtrackTo b } : rest
|
||||
else bx
|
||||
|
||||
b' = if isNothing val || (val == Just False && c)
|
||||
then b { bcktBacktracks = backtrackTo tid c b }
|
||||
else b
|
||||
in b' : case is of
|
||||
((i',c',t'):is') -> go i' bs (i'-i0-1) c' t' is'
|
||||
[] -> bs
|
||||
-- Otherwise just backtrack to everything runnable.
|
||||
| otherwise = b { bcktBacktracks = backtrackAll b } : rest
|
||||
|
||||
go (b:rest) n = b : go rest (n-1)
|
||||
go [] _ = error "backtrackAt: Ran out of schedule whilst backtracking!"
|
||||
| otherwise =
|
||||
let b' = b { bcktBacktracks = backtrackAll c b }
|
||||
in b' : case is of
|
||||
((i',c',t'):is') -> go i' bs (i'-i0-1) c' t' is'
|
||||
[] -> bs
|
||||
go i0 (b:bs) i c tid is = b : go i0 bs (i-1) c tid is
|
||||
go _ [] _ _ _ _ = err "backtrackAt" "ran out of schedule whilst backtracking!"
|
||||
|
||||
-- Backtrack to a single thread
|
||||
backtrackTo = M.insert tid conservative . bcktBacktracks
|
||||
backtrackTo tid c = M.insert tid c . bcktBacktracks
|
||||
|
||||
-- Backtrack to all runnable threads
|
||||
backtrackAll = M.map (const conservative) . bcktRunnable
|
||||
backtrackAll c = M.map (const c) . bcktRunnable
|
||||
|
||||
-- | DPOR scheduler: takes a list of decisions, and maintains a trace
|
||||
-- including the runnable threads, and the alternative choices allowed
|
||||
@ -433,17 +413,14 @@ backtrackAt toAll conservative bs i tid = go bs i where
|
||||
-- the prior thread if it's (1) still runnable and (2) hasn't just
|
||||
-- yielded. Furthermore, threads which /will/ yield are ignored in
|
||||
-- preference of those which will not.
|
||||
--
|
||||
-- This forces full evaluation of the result every step, to avoid any
|
||||
-- possible space leaks.
|
||||
dporSched
|
||||
:: (DepState -> (ThreadId, ThreadAction) -> (ThreadId, ThreadAction) -> Bool)
|
||||
:: (DepState -> ThreadId -> ThreadAction -> ThreadId -> ThreadAction -> Bool)
|
||||
-- ^ Dependency function.
|
||||
-> BoundFunc
|
||||
-- ^ Bound function: returns true if that schedule prefix terminated
|
||||
-- with the lookahead decision fits within the bound.
|
||||
-> Scheduler SchedState
|
||||
dporSched dependency inBound trc prior threads s = force schedule where
|
||||
dporSched dependency inBound trc prior threads s = schedule where
|
||||
-- Pick a thread to run.
|
||||
schedule = case schedPrefix s of
|
||||
-- If there is a decision available, make it
|
||||
@ -455,7 +432,7 @@ dporSched dependency inBound trc prior threads s = force schedule where
|
||||
[] ->
|
||||
let choices = restrictToBound initialise
|
||||
checkDep t a = case prior of
|
||||
Just (tid, act) -> dependency (schedDepState s) (tid, act) (t, a)
|
||||
Just (tid, act) -> dependency (schedDepState s) tid act t a
|
||||
Nothing -> False
|
||||
ssleep' = M.filterWithKey (\t a -> not $ checkDep t a) $ schedSleep s
|
||||
choices' = filter (`notElem` M.keys ssleep') choices
|
||||
@ -470,7 +447,7 @@ dporSched dependency inBound trc prior threads s = force schedule where
|
||||
{ schedBPoints = schedBPoints s |> (threads, rest)
|
||||
, schedDepState = nextDepState
|
||||
}
|
||||
nextDepState = let ds = schedDepState s in maybe ds (updateDepState ds) prior
|
||||
nextDepState = let ds = schedDepState s in maybe ds (uncurry $ updateDepState ds) prior
|
||||
|
||||
-- Pick a new thread to run, not considering bounds. Choose the
|
||||
-- current thread if available and it hasn't just yielded, otherwise
|
||||
@ -537,11 +514,7 @@ data DepState = DepState
|
||||
-- the masking state is assumed to be @Unmasked@. This nicely
|
||||
-- provides compatibility with dpor-0.1, where the thread IDs are
|
||||
-- not available.
|
||||
}
|
||||
|
||||
instance NFData DepState where
|
||||
-- Cheats: 'MaskingState' has no 'NFData' instance.
|
||||
rnf ds = rnf (depCRState ds, M.keys (depMaskState ds))
|
||||
} deriving (Eq, Show)
|
||||
|
||||
-- | Initial dependency state.
|
||||
initialDepState :: DepState
|
||||
@ -549,8 +522,8 @@ initialDepState = DepState M.empty M.empty
|
||||
|
||||
-- | Update the 'CRef' buffer state with the action that has just
|
||||
-- happened.
|
||||
updateDepState :: DepState -> (ThreadId, ThreadAction) -> DepState
|
||||
updateDepState depstate (tid, act) = DepState
|
||||
updateDepState :: DepState -> ThreadId -> ThreadAction -> DepState
|
||||
updateDepState depstate tid act = DepState
|
||||
{ depCRState = updateCRState act $ depCRState depstate
|
||||
, depMaskState = updateMaskState tid act $ depMaskState depstate
|
||||
}
|
||||
@ -698,3 +671,17 @@ toDotFiltered check showTid showAct = digraph . go "L" where
|
||||
-- | Internal errors.
|
||||
err :: String -> String -> a
|
||||
err func msg = error (func ++ ": (internal error) " ++ msg)
|
||||
|
||||
-- | A combination of 'partition' and 'concat'.
|
||||
concatPartition :: (a -> Bool) -> [[a]] -> ([a], [a])
|
||||
{-# INLINE concatPartition #-}
|
||||
-- note: `foldr (flip (foldr select))` is slow, as is `foldl (foldl
|
||||
-- select))`, and `foldl'` variants. The sweet spot seems to be `foldl
|
||||
-- (foldr select)` for some reason I don't really understand.
|
||||
concatPartition p = foldl (foldr select) ([], []) where
|
||||
-- Lazy pattern matching, got this trick from the 'partition'
|
||||
-- implementation. This reduces allocation fairly significantly; I
|
||||
-- do not know why.
|
||||
select a ~(ts, fs)
|
||||
| p a = (a:ts, fs)
|
||||
| otherwise = (ts, a:fs)
|
||||
|
@ -96,7 +96,6 @@ library
|
||||
build-depends: base >=4.8 && <5
|
||||
, concurrency >=1.0 && <1.1
|
||||
, containers >=0.5 && <0.6
|
||||
, deepseq >=1.3 && <1.5
|
||||
, exceptions >=0.7 && <0.9
|
||||
, monad-loops >=0.4 && <0.5
|
||||
, mtl >=2.2 && <2.3
|
||||
|
Loading…
Reference in New Issue
Block a user