Make things monomorphic over the STM type.

Only `STMLike n r` is used, so passing around the type variable and
STM runner is just noise.
This commit is contained in:
Michael Walker 2017-02-02 14:45:24 +00:00
parent 1b0adc541d
commit 5114fd37a1
5 changed files with 99 additions and 102 deletions

View File

@ -66,7 +66,7 @@ import Test.DejaFu.STM
{-# ANN module ("HLint: ignore Avoid lambda" :: String) #-} {-# ANN module ("HLint: ignore Avoid lambda" :: String) #-}
{-# ANN module ("HLint: ignore Use const" :: String) #-} {-# ANN module ("HLint: ignore Use const" :: String) #-}
newtype Conc n r a = C { unC :: M n r (STMLike n r) a } deriving (Functor, Applicative, Monad) newtype Conc n r a = C { unC :: M n r a } deriving (Functor, Applicative, Monad)
-- | A 'MonadConc' implementation using @ST@, this should be preferred -- | A 'MonadConc' implementation using @ST@, this should be preferred
-- if you do not need 'liftIO'. -- if you do not need 'liftIO'.
@ -75,10 +75,10 @@ type ConcST t = Conc (ST t) (STRef t)
-- | A 'MonadConc' implementation using @IO@. -- | A 'MonadConc' implementation using @IO@.
type ConcIO = Conc IO IORef type ConcIO = Conc IO IORef
toConc :: ((a -> Action n r (STMLike n r)) -> Action n r (STMLike n r)) -> Conc n r a toConc :: ((a -> Action n r) -> Action n r) -> Conc n r a
toConc = C . cont toConc = C . cont
wrap :: (M n r (STMLike n r) a -> M n r (STMLike n r) a) -> Conc n r a -> Conc n r a wrap :: (M n r a -> M n r a) -> Conc n r a -> Conc n r a
wrap f = C . f . unC wrap f = C . f . unC
instance IO.MonadIO ConcIO where instance IO.MonadIO ConcIO where
@ -180,7 +180,7 @@ runConcurrent :: MonadRef r n
-> Conc n r a -> Conc n r a
-> n (Either Failure a, s, Trace) -> n (Either Failure a, s, Trace)
runConcurrent sched memtype s ma = do runConcurrent sched memtype s ma = do
(res, s', trace) <- runConcurrency runTransaction sched memtype s (unC ma) (res, s', trace) <- runConcurrency sched memtype s (unC ma)
pure (res, s', reverse trace) pure (res, s', reverse trace)
-- | Run a concurrent computation and return its result. -- | Run a concurrent computation and return its result.

View File

@ -28,7 +28,7 @@ import Test.DejaFu.Conc.Internal.Common
import Test.DejaFu.Conc.Internal.Memory import Test.DejaFu.Conc.Internal.Memory
import Test.DejaFu.Conc.Internal.Threading import Test.DejaFu.Conc.Internal.Threading
import Test.DejaFu.Schedule import Test.DejaFu.Schedule
import Test.DejaFu.STM (Result(..)) import Test.DejaFu.STM (Result(..), runTransaction)
{-# ANN module ("HLint: ignore Use record patterns" :: String) #-} {-# ANN module ("HLint: ignore Use record patterns" :: String) #-}
{-# ANN module ("HLint: ignore Use const" :: String) #-} {-# ANN module ("HLint: ignore Use const" :: String) #-}
@ -41,20 +41,18 @@ import Test.DejaFu.STM (Result(..))
-- final state of the scheduler, and an execution trace (in reverse -- final state of the scheduler, and an execution trace (in reverse
-- order). -- order).
runConcurrency :: MonadRef r n runConcurrency :: MonadRef r n
=> (forall x. s x -> IdSource -> n (Result x, IdSource, TTrace)) => Scheduler g
-> Scheduler g
-> MemType -> MemType
-> g -> g
-> M n r s a -> M n r a
-> n (Either Failure a, g, Trace) -> n (Either Failure a, g, Trace)
runConcurrency runstm sched memtype g ma = do runConcurrency sched memtype g ma = do
ref <- newRef Nothing ref <- newRef Nothing
let c = runCont ma (AStop . writeRef ref . Just . Right) let c = runCont ma (AStop . writeRef ref . Just . Right)
let threads = launch' Unmasked initialThread (const c) M.empty let threads = launch' Unmasked initialThread (const c) M.empty
(g', trace) <- runThreads runstm (g', trace) <- runThreads sched
sched
memtype memtype
g g
threads threads
@ -71,9 +69,9 @@ runConcurrency runstm sched memtype g ma = do
-- efficient to prepend to a list than append. As this function isn't -- efficient to prepend to a list than append. As this function isn't
-- exposed to users of the library, this is just an internal gotcha to -- exposed to users of the library, this is just an internal gotcha to
-- watch out for. -- watch out for.
runThreads :: MonadRef r n => (forall x. s x -> IdSource -> n (Result x, IdSource, TTrace)) runThreads :: MonadRef r n
-> Scheduler g -> MemType -> g -> Threads n r s -> IdSource -> r (Maybe (Either Failure a)) -> n (g, Trace) => Scheduler g -> MemType -> g -> Threads n r -> IdSource -> r (Maybe (Either Failure a)) -> n (g, Trace)
runThreads runstm sched memtype origg origthreads idsrc ref = go idsrc [] Nothing origg origthreads emptyBuffer 2 where runThreads sched memtype origg origthreads idsrc ref = go idsrc [] Nothing origg origthreads emptyBuffer 2 where
go idSource sofar prior g threads wb caps go idSource sofar prior g threads wb caps
| isTerminated = stop g | isTerminated = stop g
| isDeadlocked = die g Deadlock | isDeadlocked = die g Deadlock
@ -82,7 +80,7 @@ runThreads runstm sched memtype origg origthreads idsrc ref = go idsrc [] Nothin
| isNonexistant = die g' InternalError | isNonexistant = die g' InternalError
| isBlocked = die g' InternalError | isBlocked = die g' InternalError
| otherwise = do | otherwise = do
stepped <- stepThread runstm sched memtype g (_continuation $ fromJust thread) idSource chosen threads wb caps stepped <- stepThread sched memtype g (_continuation $ fromJust thread) idSource chosen threads wb caps
case stepped of case stepped of
Right (threads', idSource', act, wb', caps', mg') -> loop threads' idSource' act (fromMaybe g' mg') wb' caps' Right (threads', idSource', act, wb', caps', mg') -> loop threads' idSource' act (fromMaybe g' mg') wb' caps'
@ -138,29 +136,27 @@ runThreads runstm sched memtype origg origthreads idsrc ref = go idsrc [] Nothin
-- | Run a single thread one step, by dispatching on the type of -- | Run a single thread one step, by dispatching on the type of
-- 'Action'. -- 'Action'.
stepThread :: forall n r s g. MonadRef r n stepThread :: forall n r g. MonadRef r n
=> (forall x. s x -> IdSource -> n (Result x, IdSource, TTrace)) => Scheduler g
-- ^ Run a 'MonadSTM' transaction atomically.
-> Scheduler g
-- ^ The scheduler. -- ^ The scheduler.
-> MemType -> MemType
-- ^ The memory model -- ^ The memory model
-> g -> g
-- ^ The scheduler state. -- ^ The scheduler state.
-> Action n r s -> Action n r
-- ^ Action to step -- ^ Action to step
-> IdSource -> IdSource
-- ^ Source of fresh IDs -- ^ Source of fresh IDs
-> ThreadId -> ThreadId
-- ^ ID of the current thread -- ^ ID of the current thread
-> Threads n r s -> Threads n r
-- ^ Current state of threads -- ^ Current state of threads
-> WriteBuffer r -> WriteBuffer r
-- ^ @CRef@ write buffer -- ^ @CRef@ write buffer
-> Int -> Int
-- ^ The number of capabilities -- ^ The number of capabilities
-> n (Either Failure (Threads n r s, IdSource, Either (ThreadAction, Trace) ThreadAction, WriteBuffer r, Int, Maybe g)) -> n (Either Failure (Threads n r, IdSource, Either (ThreadAction, Trace) ThreadAction, WriteBuffer r, Int, Maybe g))
stepThread runstm sched memtype g action idSource tid threads wb caps = case action of stepThread sched memtype g action idSource tid threads wb caps = case action of
AFork n a b -> stepFork n a b AFork n a b -> stepFork n a b
AMyTId c -> stepMyTId c AMyTId c -> stepMyTId c
AGetNumCapabilities c -> stepGetNumCapabilities c AGetNumCapabilities c -> stepGetNumCapabilities c
@ -199,9 +195,9 @@ stepThread runstm sched memtype g action idSource tid threads wb caps = case act
-- Explicit type signature needed for GHC 8. Looks like the -- Explicit type signature needed for GHC 8. Looks like the
-- impredicative polymorphism checks got stronger. -- impredicative polymorphism checks got stronger.
stepFork :: String stepFork :: String
-> ((forall b. M n r s b -> M n r s b) -> Action n r s) -> ((forall b. M n r b -> M n r b) -> Action n r)
-> (ThreadId -> Action n r s) -> (ThreadId -> Action n r)
-> n (Either Failure (Threads n r s, IdSource, Either z ThreadAction, WriteBuffer r, Int, Maybe g)) -> n (Either Failure (Threads n r, IdSource, Either z ThreadAction, WriteBuffer r, Int, Maybe g))
stepFork n a b = return $ Right (goto (b newtid) tid threads', idSource', Right (Fork newtid), wb, caps, Nothing) where stepFork n a b = return $ Right (goto (b newtid) tid threads', idSource', Right (Fork newtid), wb, caps, Nothing) where
threads' = launch tid newtid a threads threads' = launch tid newtid a threads
(idSource', newtid) = nextTId n idSource (idSource', newtid) = nextTId n idSource
@ -308,7 +304,7 @@ stepThread runstm sched memtype g action idSource tid threads wb caps = case act
-- | Run a STM transaction atomically. -- | Run a STM transaction atomically.
stepAtom stm c = synchronised $ do stepAtom stm c = synchronised $ do
(res, idSource', trace) <- runstm stm idSource (res, idSource', trace) <- runTransaction stm idSource
case res of case res of
Success _ written val -> Success _ written val ->
let (threads', woken) = wake (OnTVar written) threads let (threads', woken) = wake (OnTVar written) threads
@ -362,9 +358,9 @@ stepThread runstm sched memtype g action idSource tid threads wb caps = case act
-- Explicit type sig necessary for checking in the prescence of -- Explicit type sig necessary for checking in the prescence of
-- 'umask', sadly. -- 'umask', sadly.
stepMasking :: MaskingState stepMasking :: MaskingState
-> ((forall b. M n r s b -> M n r s b) -> M n r s a) -> ((forall b. M n r b -> M n r b) -> M n r a)
-> (a -> Action n r s) -> (a -> Action n r)
-> n (Either Failure (Threads n r s, IdSource, Either z ThreadAction, WriteBuffer r, Int, Maybe g)) -> n (Either Failure (Threads n r, IdSource, Either z ThreadAction, WriteBuffer r, Int, Maybe g))
stepMasking m ma c = simple threads' $ SetMasking False m where stepMasking m ma c = simple threads' $ SetMasking False m where
a = runCont (ma umask) (AResetMask False False m' . c) a = runCont (ma umask) (AResetMask False False m' . c)
@ -412,7 +408,7 @@ stepThread runstm sched memtype g action idSource tid threads wb caps = case act
stepSubconcurrency ma c stepSubconcurrency ma c
| M.size threads > 1 = return (Left IllegalSubconcurrency) | M.size threads > 1 = return (Left IllegalSubconcurrency)
| otherwise = do | otherwise = do
(res, g', trace) <- runConcurrency runstm sched memtype g ma (res, g', trace) <- runConcurrency sched memtype g ma
return $ Right (goto (c res) tid threads, idSource, Left (Subconcurrency, trace), wb, caps, Just g') return $ Right (goto (c res) tid threads, idSource, Left (Subconcurrency, trace), wb, caps, Just g')
-- | Helper for actions which don't touch the 'IdSource' or -- | Helper for actions which don't touch the 'IdSource' or

View File

@ -18,6 +18,7 @@ import Data.Dynamic (Dynamic)
import Data.Map.Strict (Map) import Data.Map.Strict (Map)
import Data.List.NonEmpty (NonEmpty, fromList) import Data.List.NonEmpty (NonEmpty, fromList)
import Test.DejaFu.Common import Test.DejaFu.Common
import Test.DejaFu.STM (STMLike)
{-# ANN module ("HLint: ignore Use record patterns" :: String) #-} {-# ANN module ("HLint: ignore Use record patterns" :: String) #-}
@ -32,16 +33,16 @@ import Test.DejaFu.Common
-- current expression of threads and exception handlers very difficult -- current expression of threads and exception handlers very difficult
-- (perhaps even not possible without significant reworking), so I -- (perhaps even not possible without significant reworking), so I
-- abandoned the attempt. -- abandoned the attempt.
newtype M n r s a = M { runM :: (a -> Action n r s) -> Action n r s } newtype M n r a = M { runM :: (a -> Action n r) -> Action n r }
instance Functor (M n r s) where instance Functor (M n r) where
fmap f m = M $ \ c -> runM m (c . f) fmap f m = M $ \ c -> runM m (c . f)
instance Applicative (M n r s) where instance Applicative (M n r) where
pure x = M $ \c -> AReturn $ c x pure x = M $ \c -> AReturn $ c x
f <*> v = M $ \c -> runM f (\g -> runM v (c . g)) f <*> v = M $ \c -> runM f (\g -> runM v (c . g))
instance Monad (M n r s) where instance Monad (M n r) where
return = pure return = pure
m >>= k = M $ \c -> runM m (\x -> runM (k x) c) m >>= k = M $ \c -> runM m (\x -> runM (k x) c)
@ -82,11 +83,11 @@ data Ticket a = Ticket
} }
-- | Construct a continuation-passing operation from a function. -- | Construct a continuation-passing operation from a function.
cont :: ((a -> Action n r s) -> Action n r s) -> M n r s a cont :: ((a -> Action n r) -> Action n r) -> M n r a
cont = M cont = M
-- | Run a CPS computation with the given final computation. -- | Run a CPS computation with the given final computation.
runCont :: M n r s a -> (a -> Action n r s) -> Action n r s runCont :: M n r a -> (a -> Action n r) -> Action n r
runCont = runM runCont = runM
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
@ -96,51 +97,51 @@ runCont = runM
-- only occur as a result of an action, and they cover (most of) the -- only occur as a result of an action, and they cover (most of) the
-- primitives of the concurrency. 'spawn' is absent as it is -- primitives of the concurrency. 'spawn' is absent as it is
-- implemented in terms of 'newEmptyMVar', 'fork', and 'putMVar'. -- implemented in terms of 'newEmptyMVar', 'fork', and 'putMVar'.
data Action n r s = data Action n r =
AFork String ((forall b. M n r s b -> M n r s b) -> Action n r s) (ThreadId -> Action n r s) AFork String ((forall b. M n r b -> M n r b) -> Action n r) (ThreadId -> Action n r)
| AMyTId (ThreadId -> Action n r s) | AMyTId (ThreadId -> Action n r)
| AGetNumCapabilities (Int -> Action n r s) | AGetNumCapabilities (Int -> Action n r)
| ASetNumCapabilities Int (Action n r s) | ASetNumCapabilities Int (Action n r)
| forall a. ANewVar String (MVar r a -> Action n r s) | forall a. ANewVar String (MVar r a -> Action n r)
| forall a. APutVar (MVar r a) a (Action n r s) | forall a. APutVar (MVar r a) a (Action n r)
| forall a. ATryPutVar (MVar r a) a (Bool -> Action n r s) | forall a. ATryPutVar (MVar r a) a (Bool -> Action n r)
| forall a. AReadVar (MVar r a) (a -> Action n r s) | forall a. AReadVar (MVar r a) (a -> Action n r)
| forall a. ATakeVar (MVar r a) (a -> Action n r s) | forall a. ATakeVar (MVar r a) (a -> Action n r)
| forall a. ATryTakeVar (MVar r a) (Maybe a -> Action n r s) | forall a. ATryTakeVar (MVar r a) (Maybe a -> Action n r)
| forall a. ANewRef String a (CRef r a -> Action n r s) | forall a. ANewRef String a (CRef r a -> Action n r)
| forall a. AReadRef (CRef r a) (a -> Action n r s) | forall a. AReadRef (CRef r a) (a -> Action n r)
| forall a. AReadRefCas (CRef r a) (Ticket a -> Action n r s) | forall a. AReadRefCas (CRef r a) (Ticket a -> Action n r)
| forall a b. AModRef (CRef r a) (a -> (a, b)) (b -> Action n r s) | forall a b. AModRef (CRef r a) (a -> (a, b)) (b -> Action n r)
| forall a b. AModRefCas (CRef r a) (a -> (a, b)) (b -> Action n r s) | forall a b. AModRefCas (CRef r a) (a -> (a, b)) (b -> Action n r)
| forall a. AWriteRef (CRef r a) a (Action n r s) | forall a. AWriteRef (CRef r a) a (Action n r)
| forall a. ACasRef (CRef r a) (Ticket a) a ((Bool, Ticket a) -> Action n r s) | forall a. ACasRef (CRef r a) (Ticket a) a ((Bool, Ticket a) -> Action n r)
| forall e. Exception e => AThrow e | forall e. Exception e => AThrow e
| forall e. Exception e => AThrowTo ThreadId e (Action n r s) | forall e. Exception e => AThrowTo ThreadId e (Action n r)
| forall a e. Exception e => ACatching (e -> M n r s a) (M n r s a) (a -> Action n r s) | forall a e. Exception e => ACatching (e -> M n r a) (M n r a) (a -> Action n r)
| APopCatching (Action n r s) | APopCatching (Action n r)
| forall a. AMasking MaskingState ((forall b. M n r s b -> M n r s b) -> M n r s a) (a -> Action n r s) | forall a. AMasking MaskingState ((forall b. M n r b -> M n r b) -> M n r a) (a -> Action n r)
| AResetMask Bool Bool MaskingState (Action n r s) | AResetMask Bool Bool MaskingState (Action n r)
| AMessage Dynamic (Action n r s) | AMessage Dynamic (Action n r)
| forall a. AAtom (s a) (a -> Action n r s) | forall a. AAtom (STMLike n r a) (a -> Action n r)
| ALift (n (Action n r s)) | ALift (n (Action n r))
| AYield (Action n r s) | AYield (Action n r)
| AReturn (Action n r s) | AReturn (Action n r)
| ACommit ThreadId CRefId | ACommit ThreadId CRefId
| AStop (n ()) | AStop (n ())
| forall a. ASub (M n r s a) (Either Failure a -> Action n r s) | forall a. ASub (M n r a) (Either Failure a -> Action n r)
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- * Scheduling & Traces -- * Scheduling & Traces
-- | Look as far ahead in the given continuation as possible. -- | Look as far ahead in the given continuation as possible.
lookahead :: Action n r s -> NonEmpty Lookahead lookahead :: Action n r -> NonEmpty Lookahead
lookahead = fromList . lookahead' where lookahead = fromList . lookahead' where
lookahead' (AFork _ _ _) = [WillFork] lookahead' (AFork _ _ _) = [WillFork]
lookahead' (AMyTId _) = [WillMyThreadId] lookahead' (AMyTId _) = [WillMyThreadId]

View File

@ -126,7 +126,7 @@ writeBarrier (WriteBuffer wb) = mapM_ flush $ M.elems wb where
flush = mapM_ $ \(BufferedWrite _ cref a) -> writeImmediate cref a flush = mapM_ $ \(BufferedWrite _ cref a) -> writeImmediate cref a
-- | Add phantom threads to the thread list to commit pending writes. -- | Add phantom threads to the thread list to commit pending writes.
addCommitThreads :: WriteBuffer r -> Threads n r s -> Threads n r s addCommitThreads :: WriteBuffer r -> Threads n r -> Threads n r
addCommitThreads (WriteBuffer wb) ts = ts <> M.fromList phantoms where addCommitThreads (WriteBuffer wb) ts = ts <> M.fromList phantoms where
phantoms = [ (ThreadId Nothing $ negate tid, mkthread $ fromJust c) phantoms = [ (ThreadId Nothing $ negate tid, mkthread $ fromJust c)
| ((k, b), tid) <- zip (M.toList wb) [1..] | ((k, b), tid) <- zip (M.toList wb) [1..]
@ -136,41 +136,41 @@ addCommitThreads (WriteBuffer wb) ts = ts <> M.fromList phantoms where
go EmptyL = Nothing go EmptyL = Nothing
-- | Remove phantom threads. -- | Remove phantom threads.
delCommitThreads :: Threads n r s -> Threads n r s delCommitThreads :: Threads n r -> Threads n r
delCommitThreads = M.filterWithKey $ \k _ -> k >= initialThread delCommitThreads = M.filterWithKey $ \k _ -> k >= initialThread
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- * Manipulating @MVar@s -- * Manipulating @MVar@s
-- | Put into a @MVar@, blocking if full. -- | Put into a @MVar@, blocking if full.
putIntoMVar :: MonadRef r n => MVar r a -> a -> Action n r s putIntoMVar :: MonadRef r n => MVar r a -> a -> Action n r
-> ThreadId -> Threads n r s -> n (Bool, Threads n r s, [ThreadId]) -> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
putIntoMVar cvar a c = mutMVar True cvar a (const c) putIntoMVar cvar a c = mutMVar True cvar a (const c)
-- | Try to put into a @MVar@, not blocking if full. -- | Try to put into a @MVar@, not blocking if full.
tryPutIntoMVar :: MonadRef r n => MVar r a -> a -> (Bool -> Action n r s) tryPutIntoMVar :: MonadRef r n => MVar r a -> a -> (Bool -> Action n r)
-> ThreadId -> Threads n r s -> n (Bool, Threads n r s, [ThreadId]) -> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
tryPutIntoMVar = mutMVar False tryPutIntoMVar = mutMVar False
-- | Read from a @MVar@, blocking if empty. -- | Read from a @MVar@, blocking if empty.
readFromMVar :: MonadRef r n => MVar r a -> (a -> Action n r s) readFromMVar :: MonadRef r n => MVar r a -> (a -> Action n r)
-> ThreadId -> Threads n r s -> n (Bool, Threads n r s, [ThreadId]) -> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
readFromMVar cvar c = seeMVar False True cvar (c . fromJust) readFromMVar cvar c = seeMVar False True cvar (c . fromJust)
-- | Take from a @MVar@, blocking if empty. -- | Take from a @MVar@, blocking if empty.
takeFromMVar :: MonadRef r n => MVar r a -> (a -> Action n r s) takeFromMVar :: MonadRef r n => MVar r a -> (a -> Action n r)
-> ThreadId -> Threads n r s -> n (Bool, Threads n r s, [ThreadId]) -> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
takeFromMVar cvar c = seeMVar True True cvar (c . fromJust) takeFromMVar cvar c = seeMVar True True cvar (c . fromJust)
-- | Try to take from a @MVar@, not blocking if empty. -- | Try to take from a @MVar@, not blocking if empty.
tryTakeFromMVar :: MonadRef r n => MVar r a -> (Maybe a -> Action n r s) tryTakeFromMVar :: MonadRef r n => MVar r a -> (Maybe a -> Action n r)
-> ThreadId -> Threads n r s -> n (Bool, Threads n r s, [ThreadId]) -> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
tryTakeFromMVar = seeMVar True False tryTakeFromMVar = seeMVar True False
-- | Mutate a @MVar@, in either a blocking or nonblocking way. -- | Mutate a @MVar@, in either a blocking or nonblocking way.
mutMVar :: MonadRef r n mutMVar :: MonadRef r n
=> Bool -> MVar r a -> a -> (Bool -> Action n r s) => Bool -> MVar r a -> a -> (Bool -> Action n r)
-> ThreadId -> Threads n r s -> n (Bool, Threads n r s, [ThreadId]) -> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
mutMVar blocking (MVar cvid ref) a c threadid threads = do mutMVar blocking (MVar cvid ref) a c threadid threads = do
val <- readRef ref val <- readRef ref
@ -191,8 +191,8 @@ mutMVar blocking (MVar cvid ref) a c threadid threads = do
-- | Read a @MVar@, in either a blocking or nonblocking -- | Read a @MVar@, in either a blocking or nonblocking
-- way. -- way.
seeMVar :: MonadRef r n seeMVar :: MonadRef r n
=> Bool -> Bool -> MVar r a -> (Maybe a -> Action n r s) => Bool -> Bool -> MVar r a -> (Maybe a -> Action n r)
-> ThreadId -> Threads n r s -> n (Bool, Threads n r s, [ThreadId]) -> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
seeMVar emptying blocking (MVar cvid ref) c threadid threads = do seeMVar emptying blocking (MVar cvid ref) c threadid threads = do
val <- readRef ref val <- readRef ref

View File

@ -27,22 +27,22 @@ import qualified Data.Map.Strict as M
-- * Threads -- * Threads
-- | Threads are stored in a map index by 'ThreadId'. -- | Threads are stored in a map index by 'ThreadId'.
type Threads n r s = Map ThreadId (Thread n r s) type Threads n r = Map ThreadId (Thread n r)
-- | All the state of a thread. -- | All the state of a thread.
data Thread n r s = Thread data Thread n r = Thread
{ _continuation :: Action n r s { _continuation :: Action n r
-- ^ The next action to execute. -- ^ The next action to execute.
, _blocking :: Maybe BlockedOn , _blocking :: Maybe BlockedOn
-- ^ The state of any blocks. -- ^ The state of any blocks.
, _handlers :: [Handler n r s] , _handlers :: [Handler n r]
-- ^ Stack of exception handlers -- ^ Stack of exception handlers
, _masking :: MaskingState , _masking :: MaskingState
-- ^ The exception masking state. -- ^ The exception masking state.
} }
-- | Construct a thread with just one action -- | Construct a thread with just one action
mkthread :: Action n r s -> Thread n r s mkthread :: Action n r -> Thread n r
mkthread c = Thread c Nothing [] Unmasked mkthread c = Thread c Nothing [] Unmasked
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
@ -53,7 +53,7 @@ mkthread c = Thread c Nothing [] Unmasked
data BlockedOn = OnMVarFull MVarId | OnMVarEmpty MVarId | OnTVar [TVarId] | OnMask ThreadId deriving Eq data BlockedOn = OnMVarFull MVarId | OnMVarEmpty MVarId | OnTVar [TVarId] | OnMask ThreadId deriving Eq
-- | Determine if a thread is blocked in a certain way. -- | Determine if a thread is blocked in a certain way.
(~=) :: Thread n r s -> BlockedOn -> Bool (~=) :: Thread n r -> BlockedOn -> Bool
thread ~= theblock = case (_blocking thread, theblock) of thread ~= theblock = case (_blocking thread, theblock) of
(Just (OnMVarFull _), OnMVarFull _) -> True (Just (OnMVarFull _), OnMVarFull _) -> True
(Just (OnMVarEmpty _), OnMVarEmpty _) -> True (Just (OnMVarEmpty _), OnMVarEmpty _) -> True
@ -65,11 +65,11 @@ thread ~= theblock = case (_blocking thread, theblock) of
-- * Exceptions -- * Exceptions
-- | An exception handler. -- | An exception handler.
data Handler n r s = forall e. Exception e => Handler (e -> Action n r s) data Handler n r = forall e. Exception e => Handler (e -> Action n r)
-- | Propagate an exception upwards, finding the closest handler -- | Propagate an exception upwards, finding the closest handler
-- which can deal with it. -- which can deal with it.
propagate :: SomeException -> ThreadId -> Threads n r s -> Maybe (Threads n r s) propagate :: SomeException -> ThreadId -> Threads n r -> Maybe (Threads n r)
propagate e tid threads = case M.lookup tid threads >>= go . _handlers of propagate e tid threads = case M.lookup tid threads >>= go . _handlers of
Just (act, hs) -> Just $ except act hs tid threads Just (act, hs) -> Just $ except act hs tid threads
Nothing -> Nothing Nothing -> Nothing
@ -79,40 +79,40 @@ propagate e tid threads = case M.lookup tid threads >>= go . _handlers of
go (Handler h:hs) = maybe (go hs) (\act -> Just (act, hs)) $ h <$> fromException e go (Handler h:hs) = maybe (go hs) (\act -> Just (act, hs)) $ h <$> fromException e
-- | Check if a thread can be interrupted by an exception. -- | Check if a thread can be interrupted by an exception.
interruptible :: Thread n r s -> Bool interruptible :: Thread n r -> Bool
interruptible thread = _masking thread == Unmasked || (_masking thread == MaskedInterruptible && isJust (_blocking thread)) interruptible thread = _masking thread == Unmasked || (_masking thread == MaskedInterruptible && isJust (_blocking thread))
-- | Register a new exception handler. -- | Register a new exception handler.
catching :: Exception e => (e -> Action n r s) -> ThreadId -> Threads n r s -> Threads n r s catching :: Exception e => (e -> Action n r) -> ThreadId -> Threads n r -> Threads n r
catching h = M.alter $ \(Just thread) -> Just $ thread { _handlers = Handler h : _handlers thread } catching h = M.alter $ \(Just thread) -> Just $ thread { _handlers = Handler h : _handlers thread }
-- | Remove the most recent exception handler. -- | Remove the most recent exception handler.
uncatching :: ThreadId -> Threads n r s -> Threads n r s uncatching :: ThreadId -> Threads n r -> Threads n r
uncatching = M.alter $ \(Just thread) -> Just $ thread { _handlers = tail $ _handlers thread } uncatching = M.alter $ \(Just thread) -> Just $ thread { _handlers = tail $ _handlers thread }
-- | Raise an exception in a thread. -- | Raise an exception in a thread.
except :: Action n r s -> [Handler n r s] -> ThreadId -> Threads n r s -> Threads n r s except :: Action n r -> [Handler n r] -> ThreadId -> Threads n r -> Threads n r
except act hs = M.alter $ \(Just thread) -> Just $ thread { _continuation = act, _handlers = hs, _blocking = Nothing } except act hs = M.alter $ \(Just thread) -> Just $ thread { _continuation = act, _handlers = hs, _blocking = Nothing }
-- | Set the masking state of a thread. -- | Set the masking state of a thread.
mask :: MaskingState -> ThreadId -> Threads n r s -> Threads n r s mask :: MaskingState -> ThreadId -> Threads n r -> Threads n r
mask ms = M.alter $ \(Just thread) -> Just $ thread { _masking = ms } mask ms = M.alter $ \(Just thread) -> Just $ thread { _masking = ms }
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- * Manipulating threads -- * Manipulating threads
-- | Replace the @Action@ of a thread. -- | Replace the @Action@ of a thread.
goto :: Action n r s -> ThreadId -> Threads n r s -> Threads n r s goto :: Action n r -> ThreadId -> Threads n r -> Threads n r
goto a = M.alter $ \(Just thread) -> Just (thread { _continuation = a }) goto a = M.alter $ \(Just thread) -> Just (thread { _continuation = a })
-- | Start a thread with the given ID, inheriting the masking state -- | Start a thread with the given ID, inheriting the masking state
-- from the parent thread. This ID must not already be in use! -- from the parent thread. This ID must not already be in use!
launch :: ThreadId -> ThreadId -> ((forall b. M n r s b -> M n r s b) -> Action n r s) -> Threads n r s -> Threads n r s launch :: ThreadId -> ThreadId -> ((forall b. M n r b -> M n r b) -> Action n r) -> Threads n r -> Threads n r
launch parent tid a threads = launch' ms tid a threads where launch parent tid a threads = launch' ms tid a threads where
ms = fromMaybe Unmasked $ _masking <$> M.lookup parent threads ms = fromMaybe Unmasked $ _masking <$> M.lookup parent threads
-- | Start a thread with the given ID and masking state. This must not already be in use! -- | Start a thread with the given ID and masking state. This must not already be in use!
launch' :: MaskingState -> ThreadId -> ((forall b. M n r s b -> M n r s b) -> Action n r s) -> Threads n r s -> Threads n r s launch' :: MaskingState -> ThreadId -> ((forall b. M n r b -> M n r b) -> Action n r) -> Threads n r -> Threads n r
launch' ms tid a = M.insert tid thread where launch' ms tid a = M.insert tid thread where
thread = Thread { _continuation = a umask, _blocking = Nothing, _handlers = [], _masking = ms } thread = Thread { _continuation = a umask, _blocking = Nothing, _handlers = [], _masking = ms }
@ -120,11 +120,11 @@ launch' ms tid a = M.insert tid thread where
resetMask typ m = cont $ \k -> AResetMask typ True m $ k () resetMask typ m = cont $ \k -> AResetMask typ True m $ k ()
-- | Kill a thread. -- | Kill a thread.
kill :: ThreadId -> Threads n r s -> Threads n r s kill :: ThreadId -> Threads n r -> Threads n r
kill = M.delete kill = M.delete
-- | Block a thread. -- | Block a thread.
block :: BlockedOn -> ThreadId -> Threads n r s -> Threads n r s block :: BlockedOn -> ThreadId -> Threads n r -> Threads n r
block blockedOn = M.alter doBlock where block blockedOn = M.alter doBlock where
doBlock (Just thread) = Just $ thread { _blocking = Just blockedOn } doBlock (Just thread) = Just $ thread { _blocking = Just blockedOn }
doBlock _ = error "Invariant failure in 'block': thread does NOT exist!" doBlock _ = error "Invariant failure in 'block': thread does NOT exist!"
@ -132,7 +132,7 @@ block blockedOn = M.alter doBlock where
-- | Unblock all threads waiting on the appropriate block. For 'TVar' -- | Unblock all threads waiting on the appropriate block. For 'TVar'
-- blocks, this will wake all threads waiting on at least one of the -- blocks, this will wake all threads waiting on at least one of the
-- given 'TVar's. -- given 'TVar's.
wake :: BlockedOn -> Threads n r s -> (Threads n r s, [ThreadId]) wake :: BlockedOn -> Threads n r -> (Threads n r, [ThreadId])
wake blockedOn threads = (unblock <$> threads, M.keys $ M.filter isBlocked threads) where wake blockedOn threads = (unblock <$> threads, M.keys $ M.filter isBlocked threads) where
unblock thread unblock thread
| isBlocked thread = thread { _blocking = Nothing } | isBlocked thread = thread { _blocking = Nothing }