Implement a STM runner.

This commit is contained in:
Michael Walker 2015-02-09 22:04:28 +00:00
parent f79f7fd245
commit 9b5e010d90
6 changed files with 308 additions and 52 deletions

37
Control/State.hs Executable file
View File

@ -0,0 +1,37 @@
{-# LANGUAGE Rank2Types #-}
-- | Dealing with mutable state.
module Control.State where
import Control.Monad.ST (ST)
import Data.IORef (IORef, newIORef, readIORef, writeIORef)
import Data.STRef (STRef, newSTRef, readSTRef, writeSTRef)
-- | Mutable references.
data Ref n r = Ref
{ newRef :: forall a. a -> n (r a)
, readRef :: forall a. r a -> n a
, writeRef :: forall a. r a -> a -> n ()
}
-- | Method dict for 'ST'.
refST :: Ref (ST t) (STRef t)
refST = Ref
{ newRef = newSTRef
, readRef = readSTRef
, writeRef = writeSTRef
}
-- | Method dict for 'IO'.
refIO :: Ref IO IORef
refIO = Ref
{ newRef = newIORef
, readRef = readIORef
, writeRef = writeIORef
}
-- | Wrapped mutable references.
data Wrapper n r m = Wrapper
{ wref :: Ref n r
, liftN :: forall a. n a -> m a
}

View File

@ -42,7 +42,8 @@ module Test.DejaFu.Deterministic
import Control.Applicative (Applicative(..), (<$>))
import Control.Monad.Cont (cont, runCont)
import Control.Monad.ST (ST, runST)
import Data.STRef (STRef, newSTRef, readSTRef, writeSTRef)
import Control.State (Wrapper(..), refST)
import Data.STRef (STRef, newSTRef)
import Test.DejaFu.Deterministic.Internal
import Test.DejaFu.Deterministic.Schedule
@ -67,12 +68,7 @@ instance C.MonadConc (Conc t) where
_concNoTest = _concNoTest
fixed :: Fixed (ST t) (STRef t)
fixed = F
{ newRef = newSTRef
, readRef = readSTRef
, writeRef = writeSTRef
, liftN = \ma -> cont (\c -> ALift $ c <$> ma)
}
fixed = Wrapper refST $ \ma -> cont (\c -> ALift $ c <$> ma)
-- | The concurrent variable type used with the 'Conc' monad. One
-- notable difference between these and 'MVar's is that 'MVar's are

View File

@ -45,7 +45,8 @@ module Test.DejaFu.Deterministic.IO
import Control.Applicative (Applicative(..), (<$>))
import Control.Monad.Cont (cont, runCont)
import Data.IORef (IORef, newIORef, readIORef, writeIORef)
import Control.State (Wrapper(..), refIO)
import Data.IORef (IORef, newIORef)
import Test.DejaFu.Deterministic.Internal
import Test.DejaFu.Deterministic.Schedule
@ -71,12 +72,7 @@ instance C.MonadConc (ConcIO t) where
_concNoTest = _concNoTest
fixed :: Fixed IO IORef
fixed = F
{ newRef = newIORef
, readRef = readIORef
, writeRef = writeIORef
, liftN = unC . liftIO
}
fixed = Wrapper refIO $ unC . liftIO
-- | The concurrent variable type used with the 'ConcIO' monad. These
-- behave the same as @Conc@'s @CVar@s

View File

@ -8,6 +8,7 @@ module Test.DejaFu.Deterministic.Internal where
import Control.DeepSeq (NFData(..))
import Control.Monad (liftM, mapAndUnzipM)
import Control.Monad.Cont (Cont, runCont)
import Control.State
import Data.List.Extra
import Data.Map (Map)
import Data.Maybe (catMaybes, fromJust, isNothing)
@ -23,17 +24,8 @@ type M n r a = Cont (Action n r) a
-- list of things blocked on it, and a unique numeric identifier.
type R r a = r (CVarId, Maybe a, [Block])
-- | Dict of methods for concrete implementations to override.
data Fixed n r = F
{ newRef :: forall a. a -> n (r a)
-- ^ Create a new reference
, readRef :: forall a. r a -> n a
-- ^ Read a reference.
, writeRef :: forall a. r a -> a -> n ()
-- ^ Overwrite the contents of a reference.
, liftN :: forall a. n a -> M n r a
-- ^ Lift an action from the underlying monad
}
-- | Dict of methods for implementations to override.
type Fixed n r = Wrapper n r (Cont (Action n r))
-- * Running @Conc@ Computations
@ -170,15 +162,15 @@ instance NFData Failure where
-- deadlock is detected. Also returned is the final state of the
-- scheduler, and an execution trace.
runFixed :: Monad n => Fixed n r
-> Scheduler s -> s -> M n r a -> n (Either Failure a, s, Trace)
-> Scheduler s -> s -> M n r a -> n (Either Failure a, s, Trace)
runFixed fixed sched s ma = do
ref <- newRef fixed Nothing
ref <- newRef (wref fixed) Nothing
let c = ma >>= liftN fixed . writeRef fixed ref . Just . Right
let c = ma >>= liftN fixed . writeRef (wref fixed) ref . Just . Right
let threads = M.fromList [(0, (runCont c $ const AStop, False))]
(s', trace) <- runThreads fixed (-1, 0) [] (negate 1) sched s threads ref
out <- readRef fixed ref
out <- readRef (wref fixed) ref
return (fromJust out, s', reverse trace)
@ -203,11 +195,11 @@ runThreads :: Monad n => Fixed n r
-> (CVarId, ThreadId) -> Trace -> ThreadId -> Scheduler s -> s -> Threads n r -> r (Maybe (Either Failure a)) -> n (s, Trace)
runThreads fixed (lastcvid, lasttid) sofar prior sched s threads ref
| isTerminated = return (s, sofar)
| isDeadlocked = writeRef fixed ref (Just $ Left Deadlock) >> return (s, sofar)
| isNonexistant = writeRef fixed ref (Just $ Left InternalError) >> return (s, sofar)
| isBlocked = writeRef fixed ref (Just $ Left InternalError) >> return (s, sofar)
| isDeadlocked = writeRef (wref fixed) ref (Just $ Left Deadlock) >> return (s, sofar)
| isNonexistant = writeRef (wref fixed) ref (Just $ Left InternalError) >> return (s, sofar)
| isBlocked = writeRef (wref fixed) ref (Just $ Left InternalError) >> return (s, sofar)
| otherwise = do
stepped <- stepThread (fst $ fromJust thread) fixed (sched, s) (lastcvid, lasttid) chosen threads
stepped <- stepThread fixed (fst $ fromJust thread) (sched, s) (lastcvid, lasttid) chosen threads
case stepped of
Right (threads', act) -> do
let sofar' = (decision, alternatives, act) : sofar
@ -217,7 +209,7 @@ runThreads fixed (lastcvid, lasttid) sofar prior sched s threads ref
runThreads fixed (lastcvid', lasttid') sofar' chosen sched s' threads' ref
Left failure -> writeRef fixed ref (Just $ Left failure) >> return (s, sofar)
Left failure -> writeRef (wref fixed) ref (Just $ Left failure) >> return (s, sofar)
where
(chosen, s') = if prior == -1 then (0, s) else sched s prior $ head runnable' :| tail runnable'
@ -241,10 +233,10 @@ runThreads fixed (lastcvid, lasttid) sofar prior sched s threads ref
-- | Run a single thread one step, by dispatching on the type of
-- 'Action'.
stepThread :: Monad n
=> Action n r
-> Fixed n r -> (Scheduler s, s) -> (CVarId, ThreadId) -> ThreadId -> Threads n r -> n (Either Failure (Threads n r, ThreadAction))
stepThread action fixed (scheduler, schedstate) (lastcvid, lasttid) tid threads = case action of
stepThread :: Monad n => Fixed n r
-> Action n r
-> (Scheduler s, s) -> (CVarId, ThreadId) -> ThreadId -> Threads n r -> n (Either Failure (Threads n r, ThreadAction))
stepThread fixed action (scheduler, schedstate) (lastcvid, lasttid) tid threads = case action of
AFork a b -> stepFork a b
APut ref a c -> stepPut ref a c
ATryPut ref a c -> stepTryPut ref a c
@ -278,7 +270,7 @@ stepThread action fixed (scheduler, schedstate) (lastcvid, lasttid) tid threads
-- | Get the value from a @CVar@, without emptying, blocking the
-- thread until it's full.
stepGet ref c = do
(cvid, val, _) <- readRef fixed ref
(cvid, val, _) <- readRef (wref fixed) ref
case val of
Just val' -> return $ Right (goto (c val') tid threads, Read cvid)
Nothing -> do
@ -325,7 +317,7 @@ stepThread action fixed (scheduler, schedstate) (lastcvid, lasttid) tid threads
-- | Get the ID of a CVar
getCVarId :: Monad n => Fixed n r -> R r a -> n CVarId
getCVarId fixed ref = (\(cvid,_,_) -> cvid) `liftM` readRef fixed ref
getCVarId fixed ref = (\(cvid,_,_) -> cvid) `liftM` readRef (wref fixed) ref
-- | Put a value into a @CVar@, in either a blocking or nonblocking
-- way.
@ -333,7 +325,7 @@ putIntoCVar :: Monad n
=> Bool -> R r a -> a -> (Bool -> Action n r)
-> Fixed n r -> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
putIntoCVar blocking ref a c fixed threadid threads = do
(cvid, val, blocks) <- readRef fixed ref
(cvid, val, blocks) <- readRef (wref fixed) ref
case val of
Just _
@ -345,7 +337,7 @@ putIntoCVar blocking ref a c fixed threadid threads = do
return (False, goto (c False) threadid threads, [])
Nothing -> do
writeRef fixed ref (cvid, Just a, blocks)
writeRef (wref fixed) ref (cvid, Just a, blocks)
(threads', woken) <- wake fixed ref WaitFull threads
return (True, goto (c True) threadid threads', woken)
@ -355,11 +347,11 @@ takeFromCVar :: Monad n
=> Bool -> R r a -> (Maybe a -> Action n r)
-> Fixed n r -> ThreadId -> Threads n r -> n (Bool, Threads n r, [ThreadId])
takeFromCVar blocking ref c fixed threadid threads = do
(cvid, val, blocks) <- readRef fixed ref
(cvid, val, blocks) <- readRef (wref fixed) ref
case val of
Just _ -> do
writeRef fixed ref (cvid, Nothing, blocks)
writeRef (wref fixed) ref (cvid, Nothing, blocks)
(threads', woken) <- wake fixed ref WaitEmpty threads
return (True, goto (c val) threadid threads', woken)
@ -378,11 +370,11 @@ goto :: Action n r -> ThreadId -> Threads n r -> Threads n r
goto a = M.alter $ \(Just (_, b)) -> Just (a, b)
-- | Block a thread on a @CVar@.
block :: Monad n => Fixed n r
-> R r a -> (ThreadId -> Block) -> ThreadId -> Threads n r -> n (Threads n r)
block :: Monad n
=> Fixed n r -> R r a -> (ThreadId -> Block) -> ThreadId -> Threads n r -> n (Threads n r)
block fixed ref typ tid threads = do
(cvid, val, blocks) <- readRef fixed ref
writeRef fixed ref (cvid, val, typ tid : blocks)
(cvid, val, blocks) <- readRef (wref fixed) ref
writeRef (wref fixed) ref (cvid, val, typ tid : blocks)
return $ M.alter (\(Just (a, _)) -> Just (a, True)) tid threads
-- | Start a thread with the given ID. This must not already be in use!
@ -394,8 +386,8 @@ kill :: ThreadId -> Threads n r -> Threads n r
kill = M.delete
-- | Wake every thread blocked on a @CVar@ read/write.
wake :: Monad n => Fixed n r
-> R r a -> (ThreadId -> Block) -> Threads n r -> n (Threads n r, [ThreadId])
wake :: Monad n
=> Fixed n r -> R r a -> (ThreadId -> Block) -> Threads n r -> n (Threads n r, [ThreadId])
wake fixed ref typ m = do
(m', woken) <- mapAndUnzipM wake' (M.toList m)
@ -404,10 +396,10 @@ wake fixed ref typ m = do
where
wake' a@(tid, (act, True)) = do
let blck = typ tid
(cvid, val, blocks) <- readRef fixed ref
(cvid, val, blocks) <- readRef (wref fixed) ref
if blck `elem` blocks
then writeRef fixed ref (cvid, val, filter (/= blck) blocks) >> return ((tid, (act, False)), Just tid)
then writeRef (wref fixed) ref (cvid, val, filter (/= blck) blocks) >> return ((tid, (act, False)), Just tid)
else return (a, Nothing)
wake' a = return (a, Nothing)

233
Test/DejaFu/STM.hs Executable file
View File

@ -0,0 +1,233 @@
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
-- | A 'MonadSTM' implementation, which can be run on top of 'IO' or
-- 'ST'.
module Test.DejaFu.STM
( -- * The @STMLike@ Monad
STMLike
, Result(..)
, runTransaction
, runTransactionST
, runTransactionIO
, retry
, orElse
, check
-- * @CTVar@s
, CTVar
, CTVarId
, initialCTVarId
, newCTVar
, readCTVar
, writeCTVar
) where
import Control.Applicative (Applicative)
import Control.Monad (liftM)
import Control.Monad.Cont (Cont, cont, runCont)
import Control.Monad.ST (ST, runST)
import Control.State
import Data.List (nub)
import Data.IORef (IORef)
import Data.STRef (STRef)
import qualified Control.Monad.STM.Class as C
-- | The 'MonadSTM' implementation, it encapsulates a single atomic
-- transaction. The environment, that is, the collection of defined
-- 'CTVar's is implicit, there is no list of them, they exist purely
-- as references. This makes the types simpler, but means you can't
-- really get an aggregate of them (if you ever wanted to for some
-- reason).
newtype STMLike t n r a = S { unS :: Cont (STMAction t n r) a } deriving (Functor, Applicative, Monad)
instance Monad n => C.MonadSTM (STMLike t n r) where
type CTVar (STMLike t n r) = CTVar t r
retry = retry
orElse = orElse
newCTVar = newCTVar
readCTVar = readCTVar
writeCTVar = writeCTVar
-- | STM transactions are represented as a sequence of primitive
-- actions.
data STMAction t n r =
forall a. ARead (CTVar t r a) (a -> STMAction t n r)
| forall a. AWrite (CTVar t r a) a (STMAction t n r)
| forall a. AOrElse (STMLike t n r a) (STMLike t n r a) (a -> STMAction t n r)
| ANew (Ref n r -> CTVarId -> n (STMAction t n r))
| ALift (n (STMAction t n r))
| ARetry
| AStop
type Fixed t n r = Wrapper n r (STMLike t n r)
fixedST :: Fixed t (ST t) (STRef t)
fixedST = Wrapper refST lift where
lift ma = S $ cont (\c -> ALift $ c `liftM` ma)
fixedIO :: Fixed t IO IORef
fixedIO = Wrapper refIO lift where
lift ma = S $ cont (\c -> ALift $ c `liftM` ma)
-- | A 'CTVar' is a tuple of a unique ID and the value contained. The
-- ID is so that blocked transactions can be re-run when a 'CTVar'
-- they depend on has changed.
newtype CTVar t r a = V (CTVarId, r a)
-- | The unique ID of a 'CTVar'. Only meaningful within a single
-- concurrent computation.
newtype CTVarId = I { unI :: Int } deriving (Eq, Ord, Show)
-- | The initial 'CTVarId'. Use this when you start your computation,
-- but after that always use the latest value returned by running a
-- transaction, or you'll get funky wake-up behaviour.
initialCTVarId :: CTVarId
initialCTVarId = I 0
-- | Abort the current transaction, restoring any 'CTVar's written to,
-- and returning the list of 'CTVar's read.
retry :: Monad n => STMLike t n r a
retry = S $ cont $ const ARetry
-- | Run the first transaction and, if it 'retry's,
orElse :: Monad n => STMLike t n r a -> STMLike t n r a -> STMLike t n r a
orElse a b = S $ cont $ \c -> AOrElse a b c
-- | Check whether a condition is true and, if not, call 'retry'.
check :: Monad n => Bool -> STMLike t n r ()
check = C.check
-- | Create a new 'CTVar' containing the given value.
newCTVar :: Monad n => a -> STMLike t n r (CTVar t r a)
newCTVar a = S $ cont lifted where
lifted c = ANew $ \ref ctvid -> c `liftM` newCTVar' ref ctvid
newCTVar' ref ctvid = (\r -> V (ctvid, r)) `liftM` newRef ref a
-- | Return the current value stored in a 'CTVar'.
readCTVar :: Monad n => CTVar t r a -> STMLike t n r a
readCTVar ctvar = S $ cont $ ARead ctvar
-- | Write the supplied value into the 'CTVar'.
writeCTVar :: Monad n => CTVar t r a -> a -> STMLike t n r ()
writeCTVar ctvar a = S $ cont $ \c -> AWrite ctvar a $ c ()
-- | The result of an STM transaction, along with which 'CTVar's it
-- touched whilst executing.
data Result a =
Success [CTVarId] a
-- ^ The transaction completed successfully, and mutated the returned 'CTVar's.
| Retry [CTVarId]
-- ^ The transaction aborted by calling 'retry', and read the
-- returned 'CTVar's. It should be retried when at least one of the
-- 'CTVar's has been mutated.
deriving (Show, Eq)
-- | Run a transaction in the 'ST' monad, starting from a clean
-- environment, and discarding the environment afterwards. This is
-- suitable for testing individual transactions, but not for composing
-- multiple ones.
runTransaction :: (forall t. STMLike t (ST t) (STRef t) a) -> Result a
runTransaction ma = fst $ runST $ runTransactionST ma initialCTVarId
-- | Run a transaction in the 'ST' monad, returning the result and new
-- initial 'CTVarId'. If the transaction ended by calling 'retry', any
-- 'CTVar' modifications are undone.
runTransactionST :: STMLike t (ST t) (STRef t) a -> CTVarId -> ST t (Result a, CTVarId)
runTransactionST ma ctvid = do
(res, undo, ctvid') <- doTransaction fixedST ma ctvid
case res of
Retry _ -> undo
_ -> return ()
return (res, ctvid')
-- | Run a transaction in the 'IO' monad, returning the result and new
-- initial 'CTVarId'. If the transaction ended by calling 'retry', any
-- 'CTVar' modifications are undone.
runTransactionIO :: STMLike t IO IORef a -> CTVarId -> IO (Result a, CTVarId)
runTransactionIO ma ctvid = do
(res, undo, ctvid') <- doTransaction fixedIO ma ctvid
case res of
Retry _ -> undo
_ -> return ()
return (res, ctvid')
-- | Run a STM transaction, returning an action to undo its effects.
doTransaction :: Monad n => Fixed t n r -> STMLike t n r a -> CTVarId -> n (Result a, n (), CTVarId)
doTransaction fixed ma newctvid = do
ref <- newRef (wref fixed) Nothing
let c = runCont (unS $ ma >>= liftN fixed . writeRef (wref fixed) ref . Just) $ const AStop
(newctvid', undo, readen, written) <- go ref c (return ()) newctvid [] []
res <- readRef (wref fixed) ref
case res of
Just val -> return (Success (nub written) val, undo, newctvid')
Nothing -> undo >> return (Retry $ nub readen, undo, newctvid')
where
go ref act undo nctvid readen written = do
(act', undo', nctvid', readen', written') <- stepTrans fixed act nctvid
case act' of
AStop -> return (nctvid', undo >> undo', readen' ++ readen, written' ++ written)
ARetry -> writeRef (wref fixed) ref Nothing >> return (nctvid', undo >> undo', readen' ++ readen, written' ++ written)
_ -> go ref act' (undo >> undo') nctvid' (readen' ++ readen) (written' ++ written)
-- | Run a transaction for one step.
stepTrans :: forall t n r. Monad n => Fixed t n r -> STMAction t n r -> CTVarId -> n (STMAction t n r, n (), CTVarId, [CTVarId], [CTVarId])
stepTrans fixed act newctvid = case act of
ARead ref c -> stepRead ref c
AWrite ref a c -> stepWrite ref a c
ANew na -> stepNew na
AOrElse a b c -> stepOrElse a b c
ALift na -> stepLift na
ARetry -> return (ARetry, nothing, newctvid, [], [])
AStop -> return (AStop, nothing, newctvid, [], [])
where
nothing = return ()
stepRead :: CTVar t r a -> (a -> STMAction t n r) -> n (STMAction t n r, n (), CTVarId, [CTVarId], [CTVarId])
stepRead (V (ctvid, ref)) c = do
val <- readRef (wref fixed) ref
return (c val, nothing, newctvid, [ctvid], [])
stepWrite :: CTVar t r a -> a -> STMAction t n r -> n (STMAction t n r, n (), CTVarId, [CTVarId], [CTVarId])
stepWrite (V (ctvid, ref)) a c = do
old <- readRef (wref fixed) ref
writeRef (wref fixed) ref a
return (c, writeRef (wref fixed) ref old, newctvid, [], [ctvid])
stepNew :: (Ref n r -> CTVarId -> n (STMAction t n r)) -> n (STMAction t n r, n (), CTVarId, [CTVarId], [CTVarId])
stepNew na = do
let newctvid' = I $ unI newctvid + 1
a <- na (wref fixed) newctvid'
return (a, nothing, newctvid', [], [])
stepOrElse :: STMLike t n r a -> STMLike t n r a -> (a -> STMAction t n r) -> n (STMAction t n r, n (), CTVarId, [CTVarId], [CTVarId])
stepOrElse a b c = do
(resa, undoa, newctvida') <- doTransaction fixed a newctvid
case resa of
Success written val -> return (c val, undoa, newctvida', [], written)
Retry _ -> do
undoa
(resb, undob, newctvidb') <- doTransaction fixed b newctvid
case resb of
Success written val -> return (c val, undob, newctvidb', [], written)
Retry readen -> return (ARetry, undob, newctvidb', readen, [])
stepLift :: n (STMAction t n r) -> n (STMAction t n r, n (), CTVarId, [CTVarId], [CTVarId])
stepLift na = do
a <- na
return (a, nothing, newctvid, [], [])

View File

@ -58,11 +58,13 @@ library
, Test.DejaFu.Deterministic.IO
, Test.DejaFu.Deterministic.Schedule
, Test.DejaFu.SCT
, Test.DejaFu.STM
other-modules: Test.DejaFu.Deterministic.Internal
, Test.DejaFu.SCT.Bounding
, Test.DejaFu.SCT.Internal
, Control.State
, Data.List.Extra
-- other-extensions: