diff --git a/dejafu-tests/Cases/MultiThreaded.hs b/dejafu-tests/Cases/MultiThreaded.hs index e84b940..43e602b 100644 --- a/dejafu-tests/Cases/MultiThreaded.hs +++ b/dejafu-tests/Cases/MultiThreaded.hs @@ -11,6 +11,7 @@ import Test.HUnit.DejaFu (testDejafu) import Control.Concurrent.Classy import Control.Monad.STM.Class +import Test.DejaFu.Conc (Conc, subconcurrency) #if __GLASGOW_HASKELL__ < 710 import Control.Applicative ((<$>), (<*>)) @@ -49,6 +50,13 @@ tests = , testGroup "Daemons" . hUnitTestToTests $ test [ testDejafu schedDaemon "schedule daemon" $ gives' [0,1] ] + + , testGroup "Subconcurrency" . hUnitTestToTests $ test + [ testDejafu scDeadlock1 "deadlock1" $ gives' [Left Deadlock, Right ()] + , testDejafu scDeadlock2 "deadlock2" $ gives' [(Left Deadlock, ()), (Right (), ())] + , testDejafu scSuccess "success" $ gives' [Right ()] + , testDejafu scIllegal "illegal" $ gives [Left IllegalSubconcurrency] + ] ] -------------------------------------------------------------------------------- @@ -207,3 +215,40 @@ schedDaemon = do x <- newCRef 0 _ <- fork $ myThreadId >> writeCRef x 1 readCRef x + +-------------------------------------------------------------------------------- +-- Subconcurrency + +-- | Subcomputation deadlocks sometimes. +scDeadlock1 :: Monad n => Conc n r (Either Failure ()) +scDeadlock1 = do + var <- newEmptyMVar + subconcurrency $ do + void . fork $ putMVar var () + putMVar var () + +-- | Subcomputation deadlocks sometimes, and action after it still +-- happens. +scDeadlock2 :: Monad n => Conc n r (Either Failure (), ()) +scDeadlock2 = do + var <- newEmptyMVar + res <- subconcurrency $ do + void . fork $ putMVar var () + putMVar var () + (,) <$> pure res <*> readMVar var + +-- | Subcomputation successfully completes. +scSuccess :: Monad n => Conc n r (Either Failure ()) +scSuccess = do + var <- newMVar () + subconcurrency $ do + out <- newEmptyMVar + void . fork $ takeMVar var >>= putMVar out + takeMVar out + +-- | Illegal usage +scIllegal :: Monad n => Conc n r () +scIllegal = do + var <- newEmptyMVar + void . fork $ readMVar var + void . subconcurrency $ pure () diff --git a/dejafu-tests/Cases/SingleThreaded.hs b/dejafu-tests/Cases/SingleThreaded.hs index 8ea54ee..698f0d7 100644 --- a/dejafu-tests/Cases/SingleThreaded.hs +++ b/dejafu-tests/Cases/SingleThreaded.hs @@ -3,6 +3,7 @@ module Cases.SingleThreaded where import Control.Exception (ArithException(..), ArrayException(..)) +import Control.Monad (void) import Test.DejaFu (Failure(..), gives, gives') import Test.Framework (Test, testGroup) import Test.Framework.Providers.HUnit (hUnitTestToTests) @@ -10,7 +11,7 @@ import Test.HUnit (test) import Test.HUnit.DejaFu (testDejafu) import Control.Concurrent.Classy -import Control.Monad.STM.Class +import Test.DejaFu.Conc (Conc, subconcurrency) import Utils @@ -58,6 +59,12 @@ tests = [ testDejafu capsGet "get" $ gives' [True] , testDejafu capsSet "set" $ gives' [True] ] + + , testGroup "Subconcurrency" . hUnitTestToTests $ test + [ testDejafu scDeadlock1 "deadlock1" $ gives' [Left Deadlock] + , testDejafu scDeadlock2 "deadlock2" $ gives' [(Left Deadlock, ())] + , testDejafu scSuccess "success" $ gives' [Right ()] + ] ] -------------------------------------------------------------------------------- @@ -252,3 +259,22 @@ capsSet = do caps <- getNumCapabilities setNumCapabilities $ caps + 1 (== caps + 1) <$> getNumCapabilities + +-------------------------------------------------------------------------------- +-- Subconcurrency + +-- | Subcomputation deadlocks. +scDeadlock1 :: Monad n => Conc n r (Either Failure ()) +scDeadlock1 = subconcurrency (newEmptyMVar >>= readMVar) + +-- | Subcomputation deadlocks, and action after it still happens. +scDeadlock2 :: Monad n => Conc n r (Either Failure (), ()) +scDeadlock2 = do + var <- newMVar () + (,) <$> subconcurrency (putMVar var ()) <*> readMVar var + +-- | Subcomputation successfully completes. +scSuccess :: Monad n => Conc n r (Either Failure ()) +scSuccess = do + var <- newMVar () + subconcurrency (takeMVar var) diff --git a/dejafu/Test/DejaFu/Conc.hs b/dejafu/Test/DejaFu/Conc.hs index 4e567b0..90b50a8 100755 --- a/dejafu/Test/DejaFu/Conc.hs +++ b/dejafu/Test/DejaFu/Conc.hs @@ -50,12 +50,10 @@ import Control.Exception (MaskingState(..)) import qualified Control.Monad.Base as Ba import qualified Control.Monad.Catch as Ca import qualified Control.Monad.IO.Class as IO -import Control.Monad.Ref (MonadRef, newRef, readRef, writeRef) +import Control.Monad.Ref (MonadRef,) import Control.Monad.ST (ST) import Data.Dynamic (toDyn) import Data.IORef (IORef) -import qualified Data.Map.Strict as M -import Data.Maybe (fromJust) import Data.STRef (STRef) import Test.DejaFu.Schedule @@ -63,7 +61,6 @@ import qualified Control.Monad.Conc.Class as C import Test.DejaFu.Common import Test.DejaFu.Conc.Internal import Test.DejaFu.Conc.Internal.Common -import Test.DejaFu.Conc.Internal.Threading import Test.DejaFu.STM {-# ANN module ("HLint: ignore Avoid lambda" :: String) #-} @@ -182,23 +179,9 @@ runConcurrent :: MonadRef r n -> s -> Conc n r a -> n (Either Failure a, s, Trace) -runConcurrent sched memtype s (C conc) = do - ref <- newRef Nothing - - let c = runCont conc (AStop . writeRef ref . Just . Right) - let threads = launch' Unmasked initialThread (const c) M.empty - - (s', trace) <- runThreads runTransaction - sched - memtype - s - threads - initialIdSource - ref - - out <- readRef ref - - pure (fromJust out, s', reverse trace) +runConcurrent sched memtype s ma = do + (res, s', trace) <- runConcurrency runTransaction sched memtype s (unC ma) + pure (res, s', reverse trace) -- | Run a concurrent computation and return its result. -- diff --git a/dejafu/Test/DejaFu/Conc/Internal.hs b/dejafu/Test/DejaFu/Conc/Internal.hs index 2b8b8ed..51edac9 100755 --- a/dejafu/Test/DejaFu/Conc/Internal.hs +++ b/dejafu/Test/DejaFu/Conc/Internal.hs @@ -16,12 +16,12 @@ module Test.DejaFu.Conc.Internal where import Control.Exception (MaskingState(..), toException) -import Control.Monad.Ref (MonadRef, newRef, writeRef) +import Control.Monad.Ref (MonadRef, newRef, readRef, writeRef) import Data.Functor (void) import Data.List (sort) import Data.List.NonEmpty (NonEmpty(..), fromList) import qualified Data.Map.Strict as M -import Data.Maybe (fromJust, isJust, isNothing, listToMaybe) +import Data.Maybe (fromJust, fromMaybe, isJust, isNothing, listToMaybe) import Test.DejaFu.Common import Test.DejaFu.Conc.Internal.Common @@ -36,6 +36,35 @@ import Test.DejaFu.STM (Result(..)) -------------------------------------------------------------------------------- -- * Execution +-- | Run a concurrent computation with a given 'Scheduler' and initial +-- state, returning a failure reason on error. Also returned is the +-- final state of the scheduler, and an execution trace (in reverse +-- order). +runConcurrency :: MonadRef r n + => (forall x. s x -> IdSource -> n (Result x, IdSource, TTrace)) + -> Scheduler g + -> MemType + -> g + -> M n r s a + -> n (Either Failure a, g, Trace) +runConcurrency runstm sched memtype g ma = do + ref <- newRef Nothing + + let c = runCont ma (AStop . writeRef ref . Just . Right) + let threads = launch' Unmasked initialThread (const c) M.empty + + (g', trace) <- runThreads runstm + sched + memtype + g + threads + initialIdSource + ref + + out <- readRef ref + + pure (fromJust out, g', trace) + -- | Run a collection of threads, until there are no threads left. -- -- Note: this returns the trace in reverse order, because it's more @@ -53,13 +82,13 @@ runThreads runstm sched memtype origg origthreads idsrc ref = go idsrc [] Nothin | isNonexistant = die g' InternalError | isBlocked = die g' InternalError | otherwise = do - stepped <- stepThread runstm memtype (_continuation $ fromJust thread) idSource chosen threads wb caps + stepped <- stepThread runstm sched memtype g (_continuation $ fromJust thread) idSource chosen threads wb caps case stepped of - Right (threads', idSource', act, wb', caps') -> loop threads' idSource' act wb' caps' + Right (threads', idSource', act, wb', caps', mg') -> loop threads' idSource' act (fromMaybe g' mg') wb' caps' Left UncaughtException | chosen == initialThread -> die g' UncaughtException - | otherwise -> loop (kill chosen threads) idSource Killed wb caps + | otherwise -> loop (kill chosen threads) idSource (Right Killed) g' wb caps Left failure -> die g' failure @@ -96,21 +125,28 @@ runThreads runstm sched memtype origg origthreads idsrc ref = go idsrc [] Nothin stop outg = pure (outg, sofar) die outg reason = writeRef ref (Just $ Left reason) >> stop outg - loop threads' idSource' act wb' = - let sofar' = ((decision, runnable', act) : sofar) + loop threads' idSource' trcOrAct g'' = + let trc = case trcOrAct of + Left (act, acts) -> (decision, runnable', act) : acts + Right act -> [(decision, runnable', act)] + sofar' = trc++sofar threads'' = if (interruptible <$> M.lookup chosen threads') /= Just False then unblockWaitingOn chosen threads' else threads' - in go idSource' sofar' (Just chosen) g' (delCommitThreads threads'') wb' + in go idSource' sofar' (Just chosen) g'' (delCommitThreads threads'') -------------------------------------------------------------------------------- -- * Single-step execution -- | Run a single thread one step, by dispatching on the type of -- 'Action'. -stepThread :: forall n r s. MonadRef r n +stepThread :: forall n r s g. MonadRef r n => (forall x. s x -> IdSource -> n (Result x, IdSource, TTrace)) -- ^ Run a 'MonadSTM' transaction atomically. + -> Scheduler g + -- ^ The scheduler. -> MemType -- ^ The memory model + -> g + -- ^ The scheduler state. -> Action n r s -- ^ Action to step -> IdSource @@ -123,8 +159,8 @@ stepThread :: forall n r s. MonadRef r n -- ^ @CRef@ write buffer -> Int -- ^ The number of capabilities - -> n (Either Failure (Threads n r s, IdSource, ThreadAction, WriteBuffer r, Int)) -stepThread runstm memtype action idSource tid threads wb caps = case action of + -> n (Either Failure (Threads n r s, IdSource, Either (ThreadAction, Trace) ThreadAction, WriteBuffer r, Int, Maybe g)) +stepThread runstm sched memtype g action idSource tid threads wb caps = case action of AFork n a b -> stepFork n a b AMyTId c -> stepMyTId c AGetNumCapabilities c -> stepGetNumCapabilities c @@ -165,8 +201,8 @@ stepThread runstm memtype action idSource tid threads wb caps = case action of stepFork :: String -> ((forall b. M n r s b -> M n r s b) -> Action n r s) -> (ThreadId -> Action n r s) - -> n (Either Failure (Threads n r s, IdSource, ThreadAction, WriteBuffer r, Int)) - stepFork n a b = return $ Right (goto (b newtid) tid threads', idSource', Fork newtid, wb, caps) where + -> n (Either Failure (Threads n r s, IdSource, Either z ThreadAction, WriteBuffer r, Int, Maybe g)) + stepFork n a b = return $ Right (goto (b newtid) tid threads', idSource', Right (Fork newtid), wb, caps, Nothing) where threads' = launch tid newtid a threads (idSource', newtid) = nextTId n idSource @@ -177,7 +213,7 @@ stepThread runstm memtype action idSource tid threads wb caps = case action of stepGetNumCapabilities c = simple (goto (c caps) tid threads) $ GetNumCapabilities caps -- | Set the number of capabilities - stepSetNumCapabilities i c = return $ Right (goto c tid threads, idSource, SetNumCapabilities i, wb, i) + stepSetNumCapabilities i c = return $ Right (goto c tid threads, idSource, Right (SetNumCapabilities i), wb, i, Nothing) -- | Yield the current thread stepYield c = simple (goto c tid threads) Yield @@ -243,12 +279,12 @@ stepThread runstm memtype action idSource tid threads wb caps = case action of -- Add to buffer using thread id. TotalStoreOrder -> do wb' <- bufferWrite wb (tid, Nothing) cref a - return $ Right (goto c tid threads, idSource, WriteRef crid, wb', caps) + return $ Right (goto c tid threads, idSource, Right (WriteRef crid), wb', caps, Nothing) -- Add to buffer using both thread id and cref id PartialStoreOrder -> do wb' <- bufferWrite wb (tid, Just crid) cref a - return $ Right (goto c tid threads, idSource, WriteRef crid, wb', caps) + return $ Right (goto c tid threads, idSource, Right (WriteRef crid), wb', caps, Nothing) -- | Perform a compare-and-swap on a @CRef@. stepCasRef cref@(CRef crid _) tick a c = synchronised $ do @@ -268,7 +304,7 @@ stepThread runstm memtype action idSource tid threads wb caps = case action of -- Commit using the cref id. PartialStoreOrder -> commitWrite wb (t, Just c) - return $ Right (threads, idSource, CommitRef t c, wb', caps) + return $ Right (threads, idSource, Right (CommitRef t c), wb', caps, Nothing) -- | Run a STM transaction atomically. stepAtom stm c = synchronised $ do @@ -276,14 +312,14 @@ stepThread runstm memtype action idSource tid threads wb caps = case action of case res of Success _ written val -> let (threads', woken) = wake (OnTVar written) threads - in return $ Right (goto (c val) tid threads', idSource', STM trace woken, wb, caps) + in return $ Right (goto (c val) tid threads', idSource', Right (STM trace woken), wb, caps, Nothing) Retry touched -> let threads' = block (OnTVar touched) tid threads - in return $ Right (threads', idSource', BlockedSTM trace, wb, caps) + in return $ Right (threads', idSource', Right (BlockedSTM trace), wb, caps, Nothing) Exception e -> do res' <- stepThrow e return $ case res' of - Right (threads', _, _, _, _) -> Right (threads', idSource', Throw, wb, caps) + Right (threads', _, _, _, _, _) -> Right (threads', idSource', Right Throw, wb, caps, Nothing) Left err -> Left err -- | Run a subcomputation in an exception-catching context. @@ -328,7 +364,7 @@ stepThread runstm memtype action idSource tid threads wb caps = case action of stepMasking :: MaskingState -> ((forall b. M n r s b -> M n r s b) -> M n r s a) -> (a -> Action n r s) - -> n (Either Failure (Threads n r s, IdSource, ThreadAction, WriteBuffer r, Int)) + -> n (Either Failure (Threads n r s, IdSource, Either z ThreadAction, WriteBuffer r, Int, Maybe g)) stepMasking m ma c = simple threads' $ SetMasking False m where a = runCont (ma umask) (AResetMask False False m' . c) @@ -348,14 +384,14 @@ stepThread runstm memtype action idSource tid threads wb caps = case action of let (idSource', newmvid) = nextMVId n idSource ref <- newRef Nothing let mvar = MVar newmvid ref - return $ Right (goto (c mvar) tid threads, idSource', NewVar newmvid, wb, caps) + return $ Right (goto (c mvar) tid threads, idSource', Right (NewVar newmvid), wb, caps, Nothing) -- | Create a new @CRef@, using the next 'CRefId'. stepNewRef n a c = do let (idSource', newcrid) = nextCRId n idSource ref <- newRef (M.empty, 0, a) let cref = CRef newcrid ref - return $ Right (goto (c cref) tid threads, idSource', NewRef newcrid, wb, caps) + return $ Right (goto (c cref) tid threads, idSource', Right (NewRef newcrid), wb, caps, Nothing) -- | Lift an action from the underlying monad into the @Conc@ -- computation. @@ -373,15 +409,16 @@ stepThread runstm memtype action idSource tid threads wb caps = case action of stepStop na = na >> simple (kill tid threads) Stop -- | Run a subconcurrent computation. - stepSubconcurrency ma k + stepSubconcurrency ma c | tid /= initialThread = return (Left IllegalSubconcurrency) | M.size threads > 1 = return (Left IllegalSubconcurrency) - -- todo: this case! - | otherwise = return (Left IllegalSubconcurrency) + | otherwise = do + (res, g', trace) <- runConcurrency runstm sched memtype g ma + return $ Right (goto (c res) tid threads, idSource, Left (Subconcurrency, trace), wb, caps, Just g') -- | Helper for actions which don't touch the 'IdSource' or -- 'WriteBuffer' - simple threads' act = return $ Right (threads', idSource, act, wb, caps) + simple threads' act = return $ Right (threads', idSource, Right act, wb, caps, Nothing) -- | Helper for actions impose a write barrier. synchronised ma = do @@ -389,5 +426,5 @@ stepThread runstm memtype action idSource tid threads wb caps = case action of res <- ma return $ case res of - Right (threads', idSource', act', _, caps') -> Right (threads', idSource', act', emptyBuffer, caps') + Right (threads', idSource', act', _, caps', g') -> Right (threads', idSource', act', emptyBuffer, caps', g') _ -> res