diff --git a/Control/Monad/Conc/Class.hs b/Control/Monad/Conc/Class.hs index 897864d..4abe730 100755 --- a/Control/Monad/Conc/Class.hs +++ b/Control/Monad/Conc/Class.hs @@ -1,4 +1,5 @@ {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE TypeFamilies #-} -- | This module captures in a typeclass the interface of concurrency @@ -7,17 +8,19 @@ module Control.Monad.Conc.Class ( MonadConc(..) -- * Utilities , spawn + , forkFinally , killThread ) where import Control.Concurrent (forkIO) import Control.Concurrent.MVar (MVar, readMVar, newEmptyMVar, putMVar, tryPutMVar, takeMVar, tryTakeMVar) -import Control.Exception (Exception, AsyncException(ThreadKilled)) +import Control.Exception (Exception, AsyncException(ThreadKilled), SomeException) import Control.Monad (unless) -import Control.Monad.Catch (MonadCatch, MonadThrow, catch, throwM) +import Control.Monad.Catch (MonadCatch, MonadThrow, MonadMask) import Control.Monad.STM (STM) import Control.Monad.STM.Class (MonadSTM) +import qualified Control.Monad.Catch as Ca import qualified Control.Concurrent as C import qualified Control.Monad.STM as S @@ -44,7 +47,7 @@ import qualified Control.Monad.STM as S -- 'takeCVar' and 'putCVar', however, are very inefficient, and should -- probably always be overridden to make use of -- implementation-specific blocking functionality. -class ( Monad m, MonadCatch m, MonadThrow m +class ( Monad m, MonadCatch m, MonadThrow m, MonadMask m , MonadSTM (STMLike m) , Eq (ThreadId m), Show (ThreadId m)) => MonadConc m where -- | The associated 'MonadSTM' for this class. @@ -63,6 +66,10 @@ class ( Monad m, MonadCatch m, MonadThrow m -- happen over @CVar@s. fork :: m () -> m (ThreadId m) + -- | Like 'fork', but the child thread is passed a function that can + -- be used to unmask asynchronous exceptions. + forkWithUnmask :: ((forall a. m a -> m a) -> m ()) -> m (ThreadId m) + -- | Get the @ThreadId@ of the current thread. myThreadId :: m (ThreadId m) @@ -107,22 +114,52 @@ class ( Monad m, MonadCatch m, MonadThrow m -- exception handler capable of dealing with it and, if one is not -- found, the thread is killed. -- - -- > throw = throwM + -- > throw = Control.Monad.Catch.throwM throw :: Exception e => e -> m a - throw = throwM + throw = Ca.throwM -- | Catch an exception. This is only required to be able to catch -- exceptions raised by 'throw', unlike the more general -- Control.Exception.catch function. If you need to be able to catch -- /all/ errors, you will have to use 'IO'. + -- + -- > catch = Control.Monad.Catch.catch catch :: Exception e => m a -> (e -> m a) -> m a - catch = Control.Monad.Catch.catch + catch = Ca.catch -- | Throw an exception to the target thread. This blocks until the -- exception is delivered, and it is just as if the target thread -- had raised it with 'throw'. This can interrupt a blocked action. throwTo :: Exception e => ThreadId m => e -> m () + -- | Executes a computation with asynchronous exceptions + -- /masked/. That is, any thread which attempts to raise an + -- exception in the current thread with 'throwTo' will be blocked + -- until asynchronous exceptions are unmasked again. + -- + -- The argument passed to mask is a function that takes as its + -- argument another function, which can be used to restore the + -- prevailing masking state within the context of the masked + -- computation. + -- + -- > mask = Control.Monad.Catch.mask + mask :: ((forall a. m a -> m a) -> m b) -> m b + mask = Ca.mask + + -- | Like 'mask', but the masked computation is not + -- interruptible. THIS SHOULD BE USED WITH GREAT CARE, because if a + -- thread executing in 'uninterruptibleMask' blocks for any reason, + -- then the thread (and possibly the program, if this is the main + -- thread) will be unresponsive and unkillable. This function should + -- only be necessary if you need to mask exceptions around an + -- interruptible operation, and you can guarantee that the + -- interruptible operation will only block for a short period of + -- time. + -- + -- > uninterruptibleMask = Control.Monad.Catch.uninterruptibleMask + uninterruptibleMask :: ((forall a. m a -> m a) -> m b) -> m b + uninterruptibleMask = Ca.uninterruptibleMask + -- | Runs its argument, just as if the @_concNoTest@ weren't there. -- -- > _concNoTest x = x @@ -149,16 +186,17 @@ instance MonadConc IO where type CVar IO = MVar type ThreadId IO = C.ThreadId - readCVar = readMVar - fork = forkIO - myThreadId = C.myThreadId - throwTo = C.throwTo - newEmptyCVar = newEmptyMVar - putCVar = putMVar - tryPutCVar = tryPutMVar - takeCVar = takeMVar - tryTakeCVar = tryTakeMVar - atomically = S.atomically + readCVar = readMVar + fork = forkIO + forkWithUnmask = C.forkIOWithUnmask + myThreadId = C.myThreadId + throwTo = C.throwTo + newEmptyCVar = newEmptyMVar + putCVar = putMVar + tryPutCVar = tryPutMVar + takeCVar = takeMVar + tryTakeCVar = tryTakeMVar + atomically = S.atomically -- | Create a concurrent computation for the provided action, and -- return a @CVar@ which can be used to query the result. @@ -168,6 +206,17 @@ spawn ma = do _ <- fork $ ma >>= putCVar cvar return cvar +-- | Fork a thread and call the supplied function when the thread is +-- about to terminate, with an exception or a returned value. The +-- function is called with asynchronous exceptions masked. +-- +-- This function is useful for informing the parent when a child +-- terminates, for example. +forkFinally :: MonadConc m => m a -> (Either SomeException a -> m ()) -> m (ThreadId m) +forkFinally action and_then = + mask $ \restore -> + fork $ Ca.try (restore action) >>= and_then + -- | Raise the 'ThreadKilled' exception in the target thread. Note -- that if the thread is prepared to catch this exception, it won't -- actually kill it. diff --git a/Test/DejaFu/Deterministic.hs b/Test/DejaFu/Deterministic.hs index 6a4a07b..fbcce3f 100755 --- a/Test/DejaFu/Deterministic.hs +++ b/Test/DejaFu/Deterministic.hs @@ -1,5 +1,5 @@ {-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE Rank2Types #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE TypeFamilies #-} -- | Deterministic traced execution of concurrent computations which @@ -14,6 +14,8 @@ module Test.DejaFu.Deterministic , Failure(..) , runConc , fork + , forkFinally + , forkWithUnmask , myThreadId , spawn , atomically @@ -21,6 +23,8 @@ module Test.DejaFu.Deterministic , throwTo , killThread , catch + , mask + , uninterruptibleMask -- * Communication: CVars , CVar @@ -70,22 +74,27 @@ instance Ca.MonadCatch (Conc t) where instance Ca.MonadThrow (Conc t) where throwM = throw +instance Ca.MonadMask (Conc t) where + mask = mask + uninterruptibleMask = uninterruptibleMask + instance C.MonadConc (Conc t) where type CVar (Conc t) = CVar t type STMLike (Conc t) = STMLike t (ST t) (STRef t) type ThreadId (Conc t) = Int - fork = fork - myThreadId = myThreadId - throwTo = throwTo - newEmptyCVar = newEmptyCVar - putCVar = putCVar - tryPutCVar = tryPutCVar - readCVar = readCVar - takeCVar = takeCVar - tryTakeCVar = tryTakeCVar - atomically = atomically - _concNoTest = _concNoTest + fork = fork + forkWithUnmask = forkWithUnmask + myThreadId = myThreadId + throwTo = throwTo + newEmptyCVar = newEmptyCVar + putCVar = putCVar + tryPutCVar = tryPutCVar + readCVar = readCVar + takeCVar = takeCVar + tryTakeCVar = tryTakeCVar + atomically = atomically + _concNoTest = _concNoTest fixed :: Fixed (ST t) (STRef t) (STMLike t) fixed = Wrapper refST $ \ma -> cont (\c -> ALift $ c <$> ma) @@ -170,6 +179,45 @@ killThread = C.killThread catch :: Exception e => Conc t a -> (e -> Conc t a) -> Conc t a catch ma h = C $ cont $ ACatching (unC . h) (unC ma) +-- | Fork a thread and call the supplied function when the thread is +-- about to terminate, with an exception or a returned value. The +-- function is called with asynchronous exceptions masked. +-- +-- This function is useful for informing the parent when a child +-- terminates, for example. +forkFinally :: Conc t a -> (Either SomeException a -> Conc t ()) -> Conc t ThreadId +forkFinally action and_then = + mask $ \restore -> + fork $ Ca.try (restore action) >>= and_then + +-- | Like 'fork', but the child thread is passed a function that can +-- be used to unmask asynchronous exceptions. +forkWithUnmask :: ((forall a. Conc t a -> Conc t a) -> Conc t ()) -> Conc t ThreadId +forkWithUnmask = error "'forkWithUnmask' not yet implemented for 'Conc'" + +-- | Executes a computation with asynchronous exceptions +-- /masked/. That is, any thread which attempts to raise an exception +-- in the current thread with 'throwTo' will be blocked until +-- asynchronous exceptions are unmasked again. +-- +-- The argument passed to mask is a function that takes as its +-- argument another function, which can be used to restore the +-- prevailing masking state within the context of the masked +-- computation. +mask :: ((forall a. Conc t a -> Conc t a) -> Conc t b) -> Conc t b +mask = error "'mask' not yet implemented for 'Conc'" + +-- | Like 'mask', but the masked computation is not +-- interruptible. THIS SHOULD BE USED WITH GREAT CARE, because if a +-- thread executing in 'uninterruptibleMask' blocks for any reason, +-- then the thread (and possibly the program, if this is the main +-- thread) will be unresponsive and unkillable. This function should +-- only be necessary if you need to mask exceptions around an +-- interruptible operation, and you can guarantee that the +-- interruptible operation will only block for a short period of time. +uninterruptibleMask :: ((forall a. Conc t a -> Conc t a) -> Conc t b) -> Conc t b +uninterruptibleMask = error "'uninterruptibleMask' not yet implemented for 'Conc'" + -- | Run the argument in one step. If the argument fails, the whole -- computation will fail. _concNoTest :: Conc t a -> Conc t a diff --git a/Test/DejaFu/Deterministic/IO.hs b/Test/DejaFu/Deterministic/IO.hs index cb26207..33a947f 100644 --- a/Test/DejaFu/Deterministic/IO.hs +++ b/Test/DejaFu/Deterministic/IO.hs @@ -1,5 +1,5 @@ {-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE Rank2Types #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE TypeFamilies #-} -- | Deterministic traced execution of concurrent computations which @@ -18,6 +18,8 @@ module Test.DejaFu.Deterministic.IO , runConcIO , liftIO , fork + , forkFinally + , forkWithUnmask , myThreadId , spawn , atomically @@ -25,6 +27,8 @@ module Test.DejaFu.Deterministic.IO , throwTo , killThread , catch + , mask + , uninterruptibleMask -- * Communication: CVars , CVar @@ -71,6 +75,10 @@ instance Ca.MonadCatch (ConcIO t) where instance Ca.MonadThrow (ConcIO t) where throwM = throw +instance Ca.MonadMask (ConcIO t) where + mask = mask + uninterruptibleMask = uninterruptibleMask + instance IO.MonadIO (ConcIO t) where liftIO = liftIO @@ -79,17 +87,18 @@ instance C.MonadConc (ConcIO t) where type STMLike (ConcIO t) = STMLike t IO IORef type ThreadId (ConcIO t) = Int - fork = fork - myThreadId = myThreadId - throwTo = throwTo - newEmptyCVar = newEmptyCVar - putCVar = putCVar - tryPutCVar = tryPutCVar - readCVar = readCVar - takeCVar = takeCVar - tryTakeCVar = tryTakeCVar - atomically = atomically - _concNoTest = _concNoTest + fork = fork + forkWithUnmask = forkWithUnmask + myThreadId = myThreadId + throwTo = throwTo + newEmptyCVar = newEmptyCVar + putCVar = putCVar + tryPutCVar = tryPutCVar + readCVar = readCVar + takeCVar = takeCVar + tryTakeCVar = tryTakeCVar + atomically = atomically + _concNoTest = _concNoTest fixed :: Fixed IO IORef (STMLike t) fixed = Wrapper refIO $ unC . liftIO @@ -175,6 +184,45 @@ killThread = C.killThread catch :: Exception e => ConcIO t a -> (e -> ConcIO t a) -> ConcIO t a catch ma h = C $ cont $ ACatching (unC . h) (unC ma) +-- | Fork a thread and call the supplied function when the thread is +-- about to terminate, with an exception or a returned value. The +-- function is called with asynchronous exceptions masked. +-- +-- This function is useful for informing the parent when a child +-- terminates, for example. +forkFinally :: ConcIO t a -> (Either SomeException a -> ConcIO t ()) -> ConcIO t ThreadId +forkFinally action and_then = + mask $ \restore -> + fork $ Ca.try (restore action) >>= and_then + +-- | Like 'fork', but the child thread is passed a function that can +-- be used to unmask asynchronous exceptions. +forkWithUnmask :: ((forall a. ConcIO t a -> ConcIO t a) -> ConcIO t ()) -> ConcIO t ThreadId +forkWithUnmask = error "'forkWithUnmask' not yet implemented for 'ConcIO'" + +-- | Executes a computation with asynchronous exceptions +-- /masked/. That is, any thread which attempts to raise an exception +-- in the current thread with 'throwTo' will be blocked until +-- asynchronous exceptions are unmasked again. +-- +-- The argument passed to mask is a function that takes as its +-- argument another function, which can be used to restore the +-- prevailing masking state within the context of the masked +-- computation. +mask :: ((forall a. ConcIO t a -> ConcIO t a) -> ConcIO t b) -> ConcIO t b +mask = error "'mask' not yet implemented for 'ConcIO'" + +-- | Like 'mask', but the masked computation is not +-- interruptible. THIS SHOULD BE USED WITH GREAT CARE, because if a +-- thread executing in 'uninterruptibleMask' blocks for any reason, +-- then the thread (and possibly the program, if this is the main +-- thread) will be unresponsive and unkillable. This function should +-- only be necessary if you need to mask exceptions around an +-- interruptible operation, and you can guarantee that the +-- interruptible operation will only block for a short period of time. +uninterruptibleMask :: ((forall a. ConcIO t a -> ConcIO t a) -> ConcIO t b) -> ConcIO t b +uninterruptibleMask = error "'uninterruptibleMask' not yet implemented for 'ConcIO'" + -- | Run the argument in one step. If the argument fails, the whole -- computation will fail. _concNoTest :: ConcIO t a -> ConcIO t a