From 9b5e010d90ce6b256c66b3985ff54cf4ded15cc2 Mon Sep 17 00:00:00 2001 From: Michael Walker Date: Mon, 9 Feb 2015 22:04:28 +0000 Subject: [PATCH] Implement a STM runner. --- Control/State.hs | 37 ++++ Test/DejaFu/Deterministic.hs | 10 +- Test/DejaFu/Deterministic/IO.hs | 10 +- Test/DejaFu/Deterministic/Internal.hs | 68 ++++---- Test/DejaFu/STM.hs | 233 ++++++++++++++++++++++++++ dejafu.cabal | 2 + 6 files changed, 308 insertions(+), 52 deletions(-) create mode 100755 Control/State.hs create mode 100755 Test/DejaFu/STM.hs diff --git a/Control/State.hs b/Control/State.hs new file mode 100755 index 0000000..4827171 --- /dev/null +++ b/Control/State.hs @@ -0,0 +1,37 @@ +{-# LANGUAGE Rank2Types #-} + +-- | Dealing with mutable state. +module Control.State where + +import Control.Monad.ST (ST) +import Data.IORef (IORef, newIORef, readIORef, writeIORef) +import Data.STRef (STRef, newSTRef, readSTRef, writeSTRef) + +-- | Mutable references. +data Ref n r = Ref + { newRef :: forall a. a -> n (r a) + , readRef :: forall a. r a -> n a + , writeRef :: forall a. r a -> a -> n () + } + +-- | Method dict for 'ST'. +refST :: Ref (ST t) (STRef t) +refST = Ref + { newRef = newSTRef + , readRef = readSTRef + , writeRef = writeSTRef + } + +-- | Method dict for 'IO'. +refIO :: Ref IO IORef +refIO = Ref + { newRef = newIORef + , readRef = readIORef + , writeRef = writeIORef + } + +-- | Wrapped mutable references. +data Wrapper n r m = Wrapper + { wref :: Ref n r + , liftN :: forall a. n a -> m a + } diff --git a/Test/DejaFu/Deterministic.hs b/Test/DejaFu/Deterministic.hs index df66675..c7b951c 100755 --- a/Test/DejaFu/Deterministic.hs +++ b/Test/DejaFu/Deterministic.hs @@ -42,7 +42,8 @@ module Test.DejaFu.Deterministic import Control.Applicative (Applicative(..), (<$>)) import Control.Monad.Cont (cont, runCont) import Control.Monad.ST (ST, runST) -import Data.STRef (STRef, newSTRef, readSTRef, writeSTRef) +import Control.State (Wrapper(..), refST) +import Data.STRef (STRef, newSTRef) import Test.DejaFu.Deterministic.Internal import Test.DejaFu.Deterministic.Schedule @@ -67,12 +68,7 @@ instance C.MonadConc (Conc t) where _concNoTest = _concNoTest fixed :: Fixed (ST t) (STRef t) -fixed = F - { newRef = newSTRef - , readRef = readSTRef - , writeRef = writeSTRef - , liftN = \ma -> cont (\c -> ALift $ c <$> ma) - } +fixed = Wrapper refST $ \ma -> cont (\c -> ALift $ c <$> ma) -- | The concurrent variable type used with the 'Conc' monad. One -- notable difference between these and 'MVar's is that 'MVar's are diff --git a/Test/DejaFu/Deterministic/IO.hs b/Test/DejaFu/Deterministic/IO.hs index 919d3b2..158b0e7 100644 --- a/Test/DejaFu/Deterministic/IO.hs +++ b/Test/DejaFu/Deterministic/IO.hs @@ -45,7 +45,8 @@ module Test.DejaFu.Deterministic.IO import Control.Applicative (Applicative(..), (<$>)) import Control.Monad.Cont (cont, runCont) -import Data.IORef (IORef, newIORef, readIORef, writeIORef) +import Control.State (Wrapper(..), refIO) +import Data.IORef (IORef, newIORef) import Test.DejaFu.Deterministic.Internal import Test.DejaFu.Deterministic.Schedule @@ -71,12 +72,7 @@ instance C.MonadConc (ConcIO t) where _concNoTest = _concNoTest fixed :: Fixed IO IORef -fixed = F - { newRef = newIORef - , readRef = readIORef - , writeRef = writeIORef - , liftN = unC . liftIO - } +fixed = Wrapper refIO $ unC . liftIO -- | The concurrent variable type used with the 'ConcIO' monad. These -- behave the same as @Conc@'s @CVar@s diff --git a/Test/DejaFu/Deterministic/Internal.hs b/Test/DejaFu/Deterministic/Internal.hs index fc23f20..d358810 100755 --- a/Test/DejaFu/Deterministic/Internal.hs +++ b/Test/DejaFu/Deterministic/Internal.hs @@ -8,6 +8,7 @@ module Test.DejaFu.Deterministic.Internal where import Control.DeepSeq (NFData(..)) import Control.Monad (liftM, mapAndUnzipM) import Control.Monad.Cont (Cont, runCont) +import Control.State import Data.List.Extra import Data.Map (Map) import Data.Maybe (catMaybes, fromJust, isNothing) @@ -23,17 +24,8 @@ type M n r a = Cont (Action n r) a -- list of things blocked on it, and a unique numeric identifier. type R r a = r (CVarId, Maybe a, [Block]) --- | Dict of methods for concrete implementations to override. -data Fixed n r = F - { newRef :: forall a. a -> n (r a) - -- ^ Create a new reference - , readRef :: forall a. r a -> n a - -- ^ Read a reference. - , writeRef :: forall a. r a -> a -> n () - -- ^ Overwrite the contents of a reference. - , liftN :: forall a. n a -> M n r a - -- ^ Lift an action from the underlying monad - } +-- | Dict of methods for implementations to override. +type Fixed n r = Wrapper n r (Cont (Action n r)) -- * Running @Conc@ Computations @@ -170,15 +162,15 @@ instance NFData Failure where -- deadlock is detected. Also returned is the final state of the -- scheduler, and an execution trace. runFixed :: Monad n => Fixed n r - -> Scheduler s -> s -> M n r a -> n (Either Failure a, s, Trace) + -> Scheduler s -> s -> M n r a -> n (Either Failure a, s, Trace) runFixed fixed sched s ma = do - ref <- newRef fixed Nothing + ref <- newRef (wref fixed) Nothing - let c = ma >>= liftN fixed . writeRef fixed ref . Just . Right + let c = ma >>= liftN fixed . writeRef (wref fixed) ref . Just . Right let threads = M.fromList [(0, (runCont c $ const AStop, False))] (s', trace) <- runThreads fixed (-1, 0) [] (negate 1) sched s threads ref - out <- readRef fixed ref + out <- readRef (wref fixed) ref return (fromJust out, s', reverse trace) @@ -203,11 +195,11 @@ runThreads :: Monad n => Fixed n r -> (CVarId, ThreadId) -> Trace -> ThreadId -> Scheduler s -> s -> Threads n r -> r (Maybe (Either Failure a)) -> n (s, Trace) runThreads fixed (lastcvid, lasttid) sofar prior sched s threads ref | isTerminated = return (s, sofar) - | isDeadlocked = writeRef fixed ref (Just $ Left Deadlock) >> return (s, sofar) - | isNonexistant = writeRef fixed ref (Just $ Left InternalError) >> return (s, sofar) - | isBlocked = writeRef fixed ref (Just $ Left InternalError) >> return (s, sofar) + | isDeadlocked = writeRef (wref fixed) ref (Just $ Left Deadlock) >> return (s, sofar) + | isNonexistant = writeRef (wref fixed) ref (Just $ Left InternalError) >> return (s, sofar) + | isBlocked = writeRef (wref fixed) ref (Just $ Left InternalError) >> return (s, sofar) | otherwise = do - stepped <- stepThread (fst $ fromJust thread) fixed (sched, s) (lastcvid, lasttid) chosen threads + stepped <- stepThread fixed (fst $ fromJust thread) (sched, s) (lastcvid, lasttid) chosen threads case stepped of Right (threads', act) -> do let sofar' = (decision, alternatives, act) : sofar @@ -217,7 +209,7 @@ runThreads fixed (lastcvid, lasttid) sofar prior sched s threads ref runThreads fixed (lastcvid', lasttid') sofar' chosen sched s' threads' ref - Left failure -> writeRef fixed ref (Just $ Left failure) >> return (s, sofar) + Left failure -> writeRef (wref fixed) ref (Just $ Left failure) >> return (s, sofar) where (chosen, s') = if prior == -1 then (0, s) else sched s prior $ head runnable' :| tail runnable' @@ -241,10 +233,10 @@ runThreads fixed (lastcvid, lasttid) sofar prior sched s threads ref -- | Run a single thread one step, by dispatching on the type of -- 'Action'. -stepThread :: Monad n - => Action n r - -> Fixed n r -> (Scheduler s, s) -> (CVarId, ThreadId) -> ThreadId -> Threads n r -> n (Either Failure (Threads n r, ThreadAction)) -stepThread action fixed (scheduler, schedstate) (lastcvid, lasttid) tid threads = case action of +stepThread :: Monad n => Fixed n r + -> Action n r + -> (Scheduler s, s) -> (CVarId, ThreadId) -> ThreadId -> Threads n r -> n (Either Failure (Threads n r, ThreadAction)) +stepThread fixed action (scheduler, schedstate) (lastcvid, lasttid) tid threads = case action of AFork a b -> stepFork a b APut ref a c -> stepPut ref a c ATryPut ref a c -> stepTryPut ref a c @@ -278,7 +270,7 @@ stepThread action fixed (scheduler, schedstate) (lastcvid, lasttid) tid threads -- | Get the value from a @CVar@, without emptying, blocking the -- thread until it's full. stepGet ref c = do - (cvid, val, _) <- readRef fixed ref + (cvid, val, _) <- readRef (wref fixed) ref case val of Just val' -> return $ Right (goto (c val') tid threads, Read cvid) Nothing -> do @@ -325,7 +317,7 @@ stepThread action fixed (scheduler, schedstate) (lastcvid, lasttid) tid threads -- | Get the ID of a CVar getCVarId :: Monad n => Fixed n r -> R r a -> n CVarId -getCVarId fixed ref = (\(cvid,_,_) -> cvid) `liftM` readRef fixed ref +getCVarId fixed ref = (\(cvid,_,_) -> cvid) `liftM` readRef (wref fixed) ref -- | Put a value into a @CVar@, in either a blocking or nonblocking -- way. @@ -333,7 +325,7 @@ putIntoCVar :: Monad n => Bool -> R r a -> a -> (Bool -> Action n r) -> Fixed n r -> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId]) putIntoCVar blocking ref a c fixed threadid threads = do - (cvid, val, blocks) <- readRef fixed ref + (cvid, val, blocks) <- readRef (wref fixed) ref case val of Just _ @@ -345,7 +337,7 @@ putIntoCVar blocking ref a c fixed threadid threads = do return (False, goto (c False) threadid threads, []) Nothing -> do - writeRef fixed ref (cvid, Just a, blocks) + writeRef (wref fixed) ref (cvid, Just a, blocks) (threads', woken) <- wake fixed ref WaitFull threads return (True, goto (c True) threadid threads', woken) @@ -355,11 +347,11 @@ takeFromCVar :: Monad n => Bool -> R r a -> (Maybe a -> Action n r) -> Fixed n r -> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId]) takeFromCVar blocking ref c fixed threadid threads = do - (cvid, val, blocks) <- readRef fixed ref + (cvid, val, blocks) <- readRef (wref fixed) ref case val of Just _ -> do - writeRef fixed ref (cvid, Nothing, blocks) + writeRef (wref fixed) ref (cvid, Nothing, blocks) (threads', woken) <- wake fixed ref WaitEmpty threads return (True, goto (c val) threadid threads', woken) @@ -378,11 +370,11 @@ goto :: Action n r -> ThreadId -> Threads n r -> Threads n r goto a = M.alter $ \(Just (_, b)) -> Just (a, b) -- | Block a thread on a @CVar@. -block :: Monad n => Fixed n r - -> R r a -> (ThreadId -> Block) -> ThreadId -> Threads n r -> n (Threads n r) +block :: Monad n + => Fixed n r -> R r a -> (ThreadId -> Block) -> ThreadId -> Threads n r -> n (Threads n r) block fixed ref typ tid threads = do - (cvid, val, blocks) <- readRef fixed ref - writeRef fixed ref (cvid, val, typ tid : blocks) + (cvid, val, blocks) <- readRef (wref fixed) ref + writeRef (wref fixed) ref (cvid, val, typ tid : blocks) return $ M.alter (\(Just (a, _)) -> Just (a, True)) tid threads -- | Start a thread with the given ID. This must not already be in use! @@ -394,8 +386,8 @@ kill :: ThreadId -> Threads n r -> Threads n r kill = M.delete -- | Wake every thread blocked on a @CVar@ read/write. -wake :: Monad n => Fixed n r - -> R r a -> (ThreadId -> Block) -> Threads n r -> n (Threads n r, [ThreadId]) +wake :: Monad n + => Fixed n r -> R r a -> (ThreadId -> Block) -> Threads n r -> n (Threads n r, [ThreadId]) wake fixed ref typ m = do (m', woken) <- mapAndUnzipM wake' (M.toList m) @@ -404,10 +396,10 @@ wake fixed ref typ m = do where wake' a@(tid, (act, True)) = do let blck = typ tid - (cvid, val, blocks) <- readRef fixed ref + (cvid, val, blocks) <- readRef (wref fixed) ref if blck `elem` blocks - then writeRef fixed ref (cvid, val, filter (/= blck) blocks) >> return ((tid, (act, False)), Just tid) + then writeRef (wref fixed) ref (cvid, val, filter (/= blck) blocks) >> return ((tid, (act, False)), Just tid) else return (a, Nothing) wake' a = return (a, Nothing) diff --git a/Test/DejaFu/STM.hs b/Test/DejaFu/STM.hs new file mode 100755 index 0000000..37cd01b --- /dev/null +++ b/Test/DejaFu/STM.hs @@ -0,0 +1,233 @@ +{-# LANGUAGE ExistentialQuantification #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE Rank2Types #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} + +-- | A 'MonadSTM' implementation, which can be run on top of 'IO' or +-- 'ST'. +module Test.DejaFu.STM + ( -- * The @STMLike@ Monad + STMLike + , Result(..) + , runTransaction + , runTransactionST + , runTransactionIO + , retry + , orElse + , check + + -- * @CTVar@s + , CTVar + , CTVarId + , initialCTVarId + , newCTVar + , readCTVar + , writeCTVar + ) where + +import Control.Applicative (Applicative) +import Control.Monad (liftM) +import Control.Monad.Cont (Cont, cont, runCont) +import Control.Monad.ST (ST, runST) +import Control.State +import Data.List (nub) +import Data.IORef (IORef) +import Data.STRef (STRef) + +import qualified Control.Monad.STM.Class as C + +-- | The 'MonadSTM' implementation, it encapsulates a single atomic +-- transaction. The environment, that is, the collection of defined +-- 'CTVar's is implicit, there is no list of them, they exist purely +-- as references. This makes the types simpler, but means you can't +-- really get an aggregate of them (if you ever wanted to for some +-- reason). +newtype STMLike t n r a = S { unS :: Cont (STMAction t n r) a } deriving (Functor, Applicative, Monad) + +instance Monad n => C.MonadSTM (STMLike t n r) where + type CTVar (STMLike t n r) = CTVar t r + + retry = retry + orElse = orElse + newCTVar = newCTVar + readCTVar = readCTVar + writeCTVar = writeCTVar + +-- | STM transactions are represented as a sequence of primitive +-- actions. +data STMAction t n r = + forall a. ARead (CTVar t r a) (a -> STMAction t n r) + | forall a. AWrite (CTVar t r a) a (STMAction t n r) + | forall a. AOrElse (STMLike t n r a) (STMLike t n r a) (a -> STMAction t n r) + | ANew (Ref n r -> CTVarId -> n (STMAction t n r)) + | ALift (n (STMAction t n r)) + | ARetry + | AStop + +type Fixed t n r = Wrapper n r (STMLike t n r) + +fixedST :: Fixed t (ST t) (STRef t) +fixedST = Wrapper refST lift where + lift ma = S $ cont (\c -> ALift $ c `liftM` ma) + +fixedIO :: Fixed t IO IORef +fixedIO = Wrapper refIO lift where + lift ma = S $ cont (\c -> ALift $ c `liftM` ma) + +-- | A 'CTVar' is a tuple of a unique ID and the value contained. The +-- ID is so that blocked transactions can be re-run when a 'CTVar' +-- they depend on has changed. +newtype CTVar t r a = V (CTVarId, r a) + +-- | The unique ID of a 'CTVar'. Only meaningful within a single +-- concurrent computation. +newtype CTVarId = I { unI :: Int } deriving (Eq, Ord, Show) + +-- | The initial 'CTVarId'. Use this when you start your computation, +-- but after that always use the latest value returned by running a +-- transaction, or you'll get funky wake-up behaviour. +initialCTVarId :: CTVarId +initialCTVarId = I 0 + +-- | Abort the current transaction, restoring any 'CTVar's written to, +-- and returning the list of 'CTVar's read. +retry :: Monad n => STMLike t n r a +retry = S $ cont $ const ARetry + +-- | Run the first transaction and, if it 'retry's, +orElse :: Monad n => STMLike t n r a -> STMLike t n r a -> STMLike t n r a +orElse a b = S $ cont $ \c -> AOrElse a b c + +-- | Check whether a condition is true and, if not, call 'retry'. +check :: Monad n => Bool -> STMLike t n r () +check = C.check + +-- | Create a new 'CTVar' containing the given value. +newCTVar :: Monad n => a -> STMLike t n r (CTVar t r a) +newCTVar a = S $ cont lifted where + lifted c = ANew $ \ref ctvid -> c `liftM` newCTVar' ref ctvid + newCTVar' ref ctvid = (\r -> V (ctvid, r)) `liftM` newRef ref a + +-- | Return the current value stored in a 'CTVar'. +readCTVar :: Monad n => CTVar t r a -> STMLike t n r a +readCTVar ctvar = S $ cont $ ARead ctvar + +-- | Write the supplied value into the 'CTVar'. +writeCTVar :: Monad n => CTVar t r a -> a -> STMLike t n r () +writeCTVar ctvar a = S $ cont $ \c -> AWrite ctvar a $ c () + +-- | The result of an STM transaction, along with which 'CTVar's it +-- touched whilst executing. +data Result a = + Success [CTVarId] a + -- ^ The transaction completed successfully, and mutated the returned 'CTVar's. + | Retry [CTVarId] + -- ^ The transaction aborted by calling 'retry', and read the + -- returned 'CTVar's. It should be retried when at least one of the + -- 'CTVar's has been mutated. + deriving (Show, Eq) + +-- | Run a transaction in the 'ST' monad, starting from a clean +-- environment, and discarding the environment afterwards. This is +-- suitable for testing individual transactions, but not for composing +-- multiple ones. +runTransaction :: (forall t. STMLike t (ST t) (STRef t) a) -> Result a +runTransaction ma = fst $ runST $ runTransactionST ma initialCTVarId + +-- | Run a transaction in the 'ST' monad, returning the result and new +-- initial 'CTVarId'. If the transaction ended by calling 'retry', any +-- 'CTVar' modifications are undone. +runTransactionST :: STMLike t (ST t) (STRef t) a -> CTVarId -> ST t (Result a, CTVarId) +runTransactionST ma ctvid = do + (res, undo, ctvid') <- doTransaction fixedST ma ctvid + + case res of + Retry _ -> undo + _ -> return () + + return (res, ctvid') + +-- | Run a transaction in the 'IO' monad, returning the result and new +-- initial 'CTVarId'. If the transaction ended by calling 'retry', any +-- 'CTVar' modifications are undone. +runTransactionIO :: STMLike t IO IORef a -> CTVarId -> IO (Result a, CTVarId) +runTransactionIO ma ctvid = do + (res, undo, ctvid') <- doTransaction fixedIO ma ctvid + + case res of + Retry _ -> undo + _ -> return () + + return (res, ctvid') + +-- | Run a STM transaction, returning an action to undo its effects. +doTransaction :: Monad n => Fixed t n r -> STMLike t n r a -> CTVarId -> n (Result a, n (), CTVarId) +doTransaction fixed ma newctvid = do + ref <- newRef (wref fixed) Nothing + + let c = runCont (unS $ ma >>= liftN fixed . writeRef (wref fixed) ref . Just) $ const AStop + + (newctvid', undo, readen, written) <- go ref c (return ()) newctvid [] [] + + res <- readRef (wref fixed) ref + + case res of + Just val -> return (Success (nub written) val, undo, newctvid') + Nothing -> undo >> return (Retry $ nub readen, undo, newctvid') + + where + go ref act undo nctvid readen written = do + (act', undo', nctvid', readen', written') <- stepTrans fixed act nctvid + case act' of + AStop -> return (nctvid', undo >> undo', readen' ++ readen, written' ++ written) + ARetry -> writeRef (wref fixed) ref Nothing >> return (nctvid', undo >> undo', readen' ++ readen, written' ++ written) + _ -> go ref act' (undo >> undo') nctvid' (readen' ++ readen) (written' ++ written) + +-- | Run a transaction for one step. +stepTrans :: forall t n r. Monad n => Fixed t n r -> STMAction t n r -> CTVarId -> n (STMAction t n r, n (), CTVarId, [CTVarId], [CTVarId]) +stepTrans fixed act newctvid = case act of + ARead ref c -> stepRead ref c + AWrite ref a c -> stepWrite ref a c + ANew na -> stepNew na + AOrElse a b c -> stepOrElse a b c + ALift na -> stepLift na + ARetry -> return (ARetry, nothing, newctvid, [], []) + AStop -> return (AStop, nothing, newctvid, [], []) + + where + nothing = return () + + stepRead :: CTVar t r a -> (a -> STMAction t n r) -> n (STMAction t n r, n (), CTVarId, [CTVarId], [CTVarId]) + stepRead (V (ctvid, ref)) c = do + val <- readRef (wref fixed) ref + return (c val, nothing, newctvid, [ctvid], []) + + stepWrite :: CTVar t r a -> a -> STMAction t n r -> n (STMAction t n r, n (), CTVarId, [CTVarId], [CTVarId]) + stepWrite (V (ctvid, ref)) a c = do + old <- readRef (wref fixed) ref + writeRef (wref fixed) ref a + return (c, writeRef (wref fixed) ref old, newctvid, [], [ctvid]) + + stepNew :: (Ref n r -> CTVarId -> n (STMAction t n r)) -> n (STMAction t n r, n (), CTVarId, [CTVarId], [CTVarId]) + stepNew na = do + let newctvid' = I $ unI newctvid + 1 + a <- na (wref fixed) newctvid' + return (a, nothing, newctvid', [], []) + + stepOrElse :: STMLike t n r a -> STMLike t n r a -> (a -> STMAction t n r) -> n (STMAction t n r, n (), CTVarId, [CTVarId], [CTVarId]) + stepOrElse a b c = do + (resa, undoa, newctvida') <- doTransaction fixed a newctvid + case resa of + Success written val -> return (c val, undoa, newctvida', [], written) + Retry _ -> do + undoa + (resb, undob, newctvidb') <- doTransaction fixed b newctvid + case resb of + Success written val -> return (c val, undob, newctvidb', [], written) + Retry readen -> return (ARetry, undob, newctvidb', readen, []) + + stepLift :: n (STMAction t n r) -> n (STMAction t n r, n (), CTVarId, [CTVarId], [CTVarId]) + stepLift na = do + a <- na + return (a, nothing, newctvid, [], []) diff --git a/dejafu.cabal b/dejafu.cabal index 6f70fc3..147d303 100755 --- a/dejafu.cabal +++ b/dejafu.cabal @@ -58,11 +58,13 @@ library , Test.DejaFu.Deterministic.IO , Test.DejaFu.Deterministic.Schedule , Test.DejaFu.SCT + , Test.DejaFu.STM other-modules: Test.DejaFu.Deterministic.Internal , Test.DejaFu.SCT.Bounding , Test.DejaFu.SCT.Internal + , Control.State , Data.List.Extra -- other-extensions: