Fix monadic state capture and restore for concurrent tasks

This causes up to 30% regression in async stream generation benchmarks and up
to 200% regression in async nested benchmarks. Mostly, due to an additional
functional call that cannot be inlined.
This commit is contained in:
Harendra Kumar 2018-09-16 14:14:37 +05:30
parent ba5a8c44b8
commit e810658dfe
7 changed files with 241 additions and 44 deletions

View File

@ -5,6 +5,12 @@
* Fixed a livelock in ahead style streams. The problem manifests sometimes when
multiple streams are merged together in ahead style and one of them is a nil
stream.
* As per expected concurrency semantics each forked concurrent task must run
with the monadic state captured at the fork point. This release fixes a bug,
which, in some cases caused an incorrect monadic state to be used for a
concurrent action, leading to unexpected behavior when concurrent streams are
used in a stateful monad e.g. `StateT`. Particularly, this bug cannot affect
`ReaderT`.
## 0.5.1

View File

@ -8,6 +8,7 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE UnboxedTuples #-}
@ -49,6 +50,8 @@ module Streamly.SVar
-- SVar related
, newAheadVar
, newParallelVar
, captureMonadState
, RunInIO (..)
, atomicModifyIORefCAS
, WorkerInfo (..)
@ -113,7 +116,7 @@ import Control.Exception
import Control.Monad (when)
import Control.Monad.Catch (MonadThrow)
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Trans.Control (MonadBaseControl, control)
import Control.Monad.Trans.Control (MonadBaseControl, control, StM)
import Data.Atomics
(casIORef, readForCAS, peekTicket, atomicModifyIORefCAS_,
writeBarrier, storeLoadBarrier)
@ -325,7 +328,8 @@ data Limit = Unlimited | Limited Word deriving Show
data SVar t m a = SVar
{
-- Read only state
svarStyle :: SVarStyle
svarStyle :: SVarStyle
, svarMrun :: RunInIO m
-- Shared output queue (events, length)
-- XXX For better efficiency we can try a preallocated array type (perhaps
@ -777,6 +781,18 @@ ringDoorBell sv = do
-- @since 0.1.0
type MonadAsync m = (MonadIO m, MonadBaseControl IO m, MonadThrow m)
-- When we run computations concurrently, we completely isolate the state of
-- the concurrent computations from the parent computation. The invariant is
-- that we should never be running two concurrent computations in the same
-- thread without using the runInIO function. Also, we should never be running
-- a concurrent computation in the parent thread, otherwise it may affect the
-- state of the parent which is against the defined semantics of concurrent
-- execution.
newtype RunInIO m = RunInIO { runInIO :: forall b. m b -> IO (StM m b) }
captureMonadState :: MonadBaseControl IO m => m (RunInIO m)
captureMonadState = control $ \run -> run (return $ RunInIO run)
-- Stolen from the async package. The perf improvement is modest, 2% on a
-- thread heavy benchmark (parallel composition using noop computations).
-- A version of forkIO that does not include the outer exception
@ -790,14 +806,15 @@ rawForkIO action = IO $ \ s ->
{-# INLINE doFork #-}
doFork :: MonadBaseControl IO m
=> m ()
-> RunInIO m
-> (SomeException -> IO ())
-> m ThreadId
doFork action exHandler =
control $ \runInIO ->
doFork action (RunInIO mrun) exHandler =
control $ \run ->
mask $ \restore -> do
tid <- rawForkIO $ catch (restore $ void $ runInIO action)
tid <- rawForkIO $ catch (restore $ void $ mrun action)
exHandler
runInIO (return tid)
run (return tid)
-- XXX Can we make access to remainingWork and yieldRateInfo fields in sv
-- faster, along with the fields in sv required by send?
@ -1288,7 +1305,8 @@ pushWorker yieldMax sv = do
, workerYieldCount = cntRef
, workerLatencyStart = lat
}
doFork (workLoop sv winfo) (handleChildException sv) >>= addThread sv
doFork (workLoop sv winfo) (svarMrun sv) (handleChildException sv)
>>= addThread sv
-- XXX we can push the workerCount modification in accountThread and use the
-- same pushWorker for Parallel case as well.
@ -1305,7 +1323,8 @@ pushWorkerPar
pushWorkerPar sv wloop =
if svarInspectMode sv
then forkWithDiag
else doFork (wloop Nothing) (handleChildException sv) >>= modifyThread sv
else doFork (wloop Nothing) (svarMrun sv) (handleChildException sv)
>>= modifyThread sv
where
@ -1330,7 +1349,8 @@ pushWorkerPar sv wloop =
, workerLatencyStart = lat
}
doFork (wloop winfo) (handleChildException sv) >>= modifyThread sv
doFork (wloop winfo) (svarMrun sv) (handleChildException sv)
>>= modifyThread sv
-- Returns:
-- True: can dispatch more
@ -1997,8 +2017,9 @@ getAheadSVar :: MonadAsync m
-> SVar t m a
-> Maybe WorkerInfo
-> m ())
-> RunInIO m
-> IO (SVar t m a)
getAheadSVar st f = do
getAheadSVar st f mrun = do
outQ <- newIORef ([], 0)
-- the second component of the tuple is "Nothing" when heap is being
-- cleared, "Just n" when we are expecting sequence number n to arrive
@ -2036,6 +2057,7 @@ getAheadSVar st f = do
, isQueueDone = isQueueDoneAhead sv q
, needDoorBell = wfw
, svarStyle = AheadVar
, svarMrun = mrun
, workerCount = active
, accountThread = delThread sv
, workerStopMVar = stopMVar
@ -2081,8 +2103,8 @@ getAheadSVar st f = do
(xs, _) <- readIORef q
return $ null xs
getParallelSVar :: MonadIO m => State t m a -> IO (SVar t m a)
getParallelSVar st = do
getParallelSVar :: MonadIO m => State t m a -> RunInIO m -> IO (SVar t m a)
getParallelSVar st mrun = do
outQ <- newIORef ([], 0)
outQMv <- newEmptyMVar
active <- newIORef 0
@ -2112,6 +2134,7 @@ getParallelSVar st = do
, isQueueDone = undefined
, needDoorBell = undefined
, svarStyle = ParallelVar
, svarMrun = mrun
, workerCount = active
, accountThread = modifyThread sv
, workerStopMVar = undefined
@ -2160,12 +2183,15 @@ newAheadVar :: MonadAsync m
-> m ())
-> m (SVar t m a)
newAheadVar st m wloop = do
sv <- liftIO $ getAheadSVar st wloop
mrun <- captureMonadState
sv <- liftIO $ getAheadSVar st wloop mrun
sendFirstWorker sv m
{-# INLINABLE newParallelVar #-}
newParallelVar :: MonadAsync m => State t m a -> m (SVar t m a)
newParallelVar st = liftIO $ getParallelSVar st
newParallelVar st = do
mrun <- captureMonadState
liftIO $ getParallelSVar st mrun
-- XXX this errors out for Parallel/Ahead SVars
-- | Write a stream to an 'SVar' in a non-blocking manner. The stream can then

View File

@ -47,7 +47,9 @@ import qualified Data.Heap as H
import Streamly.Streams.SVar (fromSVar)
import Streamly.Streams.Serial (map)
import Streamly.SVar
import Streamly.Streams.StreamK (IsStream(..), Stream(..))
import Streamly.Streams.StreamK
(IsStream(..), Stream(..), unstreamShared, unStreamIsolated,
runStreamSVar)
import qualified Streamly.Streams.StreamK as K
import Prelude hiding (map)
@ -296,7 +298,7 @@ processHeap q heap st sv winfo entry sno stopping = loopHeap sno entry
let stop = do
liftIO (incrementYieldLimit sv)
nextHeap seqNo
unStream r st stop
runStreamSVar sv r st stop
(singleStreamFromHeap seqNo)
(yieldStreamFromHeap seqNo)
else liftIO $ do
@ -346,7 +348,7 @@ processWithoutToken q heap st sv winfo m seqNo = do
-- we stop.
toHeap AheadEntryNull
unStream m st stop
runStreamSVar sv m st stop
(toHeap . AheadEntryPure)
(\a r -> toHeap $ AheadEntryStream $ K.cons a r)
@ -406,7 +408,7 @@ processWithToken q heap st sv winfo action sno = do
liftIO (incrementYieldLimit sv)
loopWithToken (sno + 1)
unStream action st stop (singleOutput sno) (yieldOutput sno)
runStreamSVar sv action st stop (singleOutput sno) (yieldOutput sno)
where
@ -429,7 +431,7 @@ processWithToken q heap st sv winfo action sno = do
let stop = do
liftIO (incrementYieldLimit sv)
loopWithToken (seqNo + 1)
unStream r st stop
runStreamSVar sv r st stop
(singleOutput seqNo)
(yieldOutput seqNo)
else do
@ -458,7 +460,7 @@ processWithToken q heap st sv winfo action sno = do
let stop = do
liftIO (incrementYieldLimit sv)
loopWithToken (seqNo + 1)
unStream m st stop
runStreamSVar sv m st stop
(singleOutput seqNo)
(yieldOutput seqNo)
else
@ -681,10 +683,13 @@ aheadbind m f = go m
where
go (Stream g) =
Stream $ \st stp sng yld ->
let run x = unStream x st stp sng yld
single a = run $ f a
yieldk a r = run $ f a `aheadS` go r
in g (rstState st) stp single yieldk
let runShared x = unstreamShared x st stp sng yld
runIsolated x = unStreamIsolated x st stp sng yld
single a = runIsolated $ f a
yieldk a r = runShared $
K.isolateStream (f a) `aheadS` go r
in g (rstState st) stp single yieldk
instance MonadAsync m => Monad (AheadT m) where
return = pure

View File

@ -57,7 +57,7 @@ import qualified Data.Set as S
import Streamly.Streams.SVar (fromSVar)
import Streamly.Streams.Serial (map)
import Streamly.SVar
import Streamly.Streams.StreamK (IsStream(..), Stream(..), adapt)
import Streamly.Streams.StreamK (IsStream(..), Stream(..), adapt, runStreamSVar)
import qualified Streamly.Streams.StreamK as K
#include "Instances.hs"
@ -82,7 +82,7 @@ workLoopLIFO q st sv winfo = run
work <- dequeue
case work of
Nothing -> liftIO $ sendStop sv winfo
Just m -> unStream m st run single yieldk
Just m -> runStreamSVar sv m st run single yieldk
single a = do
res <- liftIO $ sendYield sv winfo (ChildYield a)
@ -91,7 +91,7 @@ workLoopLIFO q st sv winfo = run
yieldk a r = do
res <- liftIO $ sendYield sv winfo (ChildYield a)
if res
then unStream r st run single yieldk
then runStreamSVar sv r st run single yieldk
else liftIO $ do
enqueueLIFO sv q r
sendStop sv winfo
@ -132,7 +132,7 @@ workLoopLIFOLimited q st sv winfo = run
if yieldLimitOk
then do
let stop = liftIO (incrementYieldLimit sv) >> run
unStream m st stop single yieldk
runStreamSVar sv m st stop single yieldk
-- Avoid any side effects, undo the yield limit decrement if we
-- never yielded anything.
else liftIO $ do
@ -151,7 +151,7 @@ workLoopLIFOLimited q st sv winfo = run
yieldLimitOk <- liftIO $ decrementYieldLimit sv
let stop = liftIO (incrementYieldLimit sv) >> run
if res && yieldLimitOk
then unStream r st stop single yieldk
then runStreamSVar sv r st stop single yieldk
else liftIO $ do
incrementYieldLimit sv
enqueueLIFO sv q r
@ -183,7 +183,7 @@ workLoopFIFO q st sv winfo = run
work <- liftIO $ tryPopR q
case work of
Nothing -> liftIO $ sendStop sv winfo
Just m -> unStream m st run single yieldk
Just m -> runStreamSVar sv m st run single yieldk
single a = do
res <- liftIO $ sendYield sv winfo (ChildYield a)
@ -192,7 +192,7 @@ workLoopFIFO q st sv winfo = run
yieldk a r = do
res <- liftIO $ sendYield sv winfo (ChildYield a)
if res
then unStream r st run single yieldk
then runStreamSVar sv r st run single yieldk
else liftIO $ do
enqueueFIFO sv q r
sendStop sv winfo
@ -218,7 +218,7 @@ workLoopFIFOLimited q st sv winfo = run
if yieldLimitOk
then do
let stop = liftIO (incrementYieldLimit sv) >> run
unStream m st stop single yieldk
runStreamSVar sv m st stop single yieldk
else liftIO $ do
enqueueFIFO sv q m
incrementYieldLimit sv
@ -233,7 +233,7 @@ workLoopFIFOLimited q st sv winfo = run
yieldLimitOk <- liftIO $ decrementYieldLimit sv
let stop = liftIO (incrementYieldLimit sv) >> run
if res && yieldLimitOk
then unStream r st stop single yieldk
then runStreamSVar sv r st stop single yieldk
else liftIO $ do
incrementYieldLimit sv
enqueueFIFO sv q r
@ -249,8 +249,8 @@ workLoopFIFOLimited q st sv winfo = run
-- than 10%. Need to investigate what the root cause is.
-- Interestingly, the same thing does not make any difference for Ahead.
getLifoSVar :: forall m a. MonadAsync m
=> State Stream m a -> IO (SVar Stream m a)
getLifoSVar st = do
=> State Stream m a -> RunInIO m -> IO (SVar Stream m a)
getLifoSVar st mrun = do
outQ <- newIORef ([], 0)
outQMv <- newEmptyMVar
active <- newIORef 0
@ -303,6 +303,7 @@ getLifoSVar st = do
, isQueueDone = workDone sv
, needDoorBell = wfw
, svarStyle = AsyncVar
, svarMrun = mrun
, workerCount = active
, accountThread = delThread sv
, workerStopMVar = undefined
@ -339,8 +340,8 @@ getLifoSVar st = do
in return sv
getFifoSVar :: forall m a. MonadAsync m
=> State Stream m a -> IO (SVar Stream m a)
getFifoSVar st = do
=> State Stream m a -> RunInIO m -> IO (SVar Stream m a)
getFifoSVar st mrun = do
outQ <- newIORef ([], 0)
outQMv <- newEmptyMVar
active <- newIORef 0
@ -392,6 +393,7 @@ getFifoSVar st = do
, isQueueDone = workDone sv
, needDoorBell = wfw
, svarStyle = WAsyncVar
, svarMrun = mrun
, workerCount = active
, accountThread = delThread sv
, workerStopMVar = undefined
@ -431,7 +433,8 @@ getFifoSVar st = do
newAsyncVar :: MonadAsync m
=> State Stream m a -> Stream m a -> m (SVar Stream m a)
newAsyncVar st m = do
sv <- liftIO $ getLifoSVar st
mrun <- captureMonadState
sv <- liftIO $ getLifoSVar st mrun
sendFirstWorker sv m
-- XXX Get rid of this?
@ -455,7 +458,8 @@ mkAsync' st m = fmap fromSVar (newAsyncVar st (toStream m))
newWAsyncVar :: MonadAsync m
=> State Stream m a -> Stream m a -> m (SVar Stream m a)
newWAsyncVar st m = do
sv <- liftIO $ getFifoSVar st
mrun <- captureMonadState
sv <- liftIO $ getFifoSVar st mrun
sendFirstWorker sv m
------------------------------------------------------------------------------

View File

@ -67,6 +67,7 @@ runOne st m winfo = unStream m st stop single yieldk
where
sv = fromJust $ streamVar st
mrun = runInIO $ svarMrun sv
withLimitCheck action = do
yieldLimitOk <- liftIO $ decrementYieldLimitPost sv
@ -82,7 +83,8 @@ runOne st m winfo = unStream m st stop single yieldk
-- queue and queue it back on that and exit the thread when the outputQueue
-- overflows. Parallel is dangerous because it can accumulate unbounded
-- output in the buffer.
yieldk a r = void (sendit a) >> withLimitCheck (runOne st r winfo)
yieldk a r = void (sendit a)
>> withLimitCheck (void $ liftIO $ mrun $ runOne st r winfo)
{-# NOINLINE forkSVarPar #-}
forkSVarPar :: MonadAsync m => Stream m a -> Stream m a -> Stream m a

View File

@ -33,6 +33,10 @@ module Streamly.Streams.StreamK
-- * The stream type
, Stream (..)
, unStreamIsolated
, isolateStream
, unstreamShared
, runStreamSVar
-- * Construction
, mkStream
@ -140,6 +144,7 @@ module Streamly.Streams.StreamK
where
import Control.Monad (void)
import Control.Monad.IO.Class (MonadIO(liftIO))
import Control.Monad.Reader.Class (MonadReader(..))
import Control.Monad.Trans.Class (MonadTrans(lift))
import Data.Semigroup (Semigroup(..))
@ -183,6 +188,51 @@ newtype Stream m a =
-> m r
}
-- XXX make this the default "unStream"
-- | unwraps the Stream type producing the stream function that can be run with
-- continuations.
{-# INLINE unStreamIsolated #-}
unStreamIsolated ::
Stream m a
-> State Stream m a -- state
-> m r -- stop
-> (a -> m r) -- singleton
-> (a -> Stream m a -> m r) -- yield
-> m r
unStreamIsolated x st = unStream x (rstState st)
{-# INLINE isolateStream #-}
isolateStream :: Stream m a -> Stream m a
isolateStream x = Stream $ \st stp sng yld ->
unStreamIsolated x st stp sng yld
-- | Like unstream, but passes a shared SVar across continuations.
{-# INLINE unstreamShared #-}
unstreamShared ::
Stream m a
-> State Stream m a -- state
-> m r -- stop
-> (a -> m r) -- singleton
-> (a -> Stream m a -> m r) -- yield
-> m r
unstreamShared = unStream
-- Run the stream using a run function associated with the SVar that runs the
-- streams with a captured snapshot of the monadic state.
{-# INLINE runStreamSVar #-}
runStreamSVar
:: MonadIO m
=> SVar Stream m a
-> Stream m a
-> State Stream m a -- state
-> m r -- stop
-> (a -> m r) -- singleton
-> (a -> Stream m a -> m r) -- yield
-> m ()
runStreamSVar sv m st stp sng yld =
let mrun = runInIO $ svarMrun sv
in void $ liftIO $ mrun $ unStream m st stp sng yld
------------------------------------------------------------------------------
-- Types that can behave as a Stream
------------------------------------------------------------------------------
@ -969,10 +1019,12 @@ bindWith par m f = go m
where
go (Stream g) =
Stream $ \st stp sng yld ->
let run x = unStream x st stp sng yld
single a = run $ f a
yieldk a r = run $ f a `par` go r
in g (rstState st) stp single yieldk
let runShared x = unstreamShared x st stp sng yld
runIsolated x = unStreamIsolated x st stp sng yld
single a = runIsolated $ f a
yieldk a r = runShared $ isolateStream (f a) `par` go r
in g (rstState st) stp single yieldk
------------------------------------------------------------------------------
-- Alternative & MonadPlus

View File

@ -7,8 +7,11 @@ module Main (main) where
import Control.Concurrent (threadDelay)
import Control.Exception (Exception, try, ErrorCall(..), catch, throw)
import Control.Monad (void)
import Control.Monad.Catch (throwM, MonadThrow)
import Control.Monad.Error.Class (throwError, MonadError)
import Control.Monad.IO.Class (MonadIO(liftIO))
import Control.Monad.State (MonadState, get, modify, runStateT, StateT)
import Control.Monad.Trans.Except (runExceptT, ExceptT)
import Data.Foldable (forM_, fold)
import Data.List (sort)
@ -495,6 +498,34 @@ parallelTests = H.parallel $ do
it "foldlM' is strict enough" (checkFoldMStrictness foldlM'StrictCheck)
it "scanlM' is strict enough" (checkScanlMStrictness scanlM'StrictCheck)
---------------------------------------------------------------------------
-- Monadic state snapshot in concurrent tasks
---------------------------------------------------------------------------
it "asyncly maintains independent states in concurrent tasks"
(monadicStateSnapshot asyncly)
it "asyncly limited maintains independent states in concurrent tasks"
(monadicStateSnapshot (asyncly . S.take 10000))
it "wAsyncly maintains independent states in concurrent tasks"
(monadicStateSnapshot wAsyncly)
it "wAsyncly limited maintains independent states in concurrent tasks"
(monadicStateSnapshot (wAsyncly . S.take 10000))
it "aheadly maintains independent states in concurrent tasks"
(monadicStateSnapshot aheadly)
it "aheadly limited maintains independent states in concurrent tasks"
(monadicStateSnapshot (aheadly . S.take 10000))
it "parallely maintains independent states in concurrent tasks"
(monadicStateSnapshot parallely)
it "async maintains independent states in concurrent tasks"
(monadicStateSnapshotOp async)
it "ahead maintains independent states in concurrent tasks"
(monadicStateSnapshotOp ahead)
it "wAsync maintains independent states in concurrent tasks"
(monadicStateSnapshotOp wAsync)
it "parallel maintains independent states in concurrent tasks"
(monadicStateSnapshotOp Streamly.parallel)
---------------------------------------------------------------------------
-- Slower tests are at the end
---------------------------------------------------------------------------
@ -513,6 +544,77 @@ parallelTests = H.parallel $ do
replicate 4000 $ S.yieldM $ threadDelay 1000000)
`shouldReturn` ()
-- Each snapshot carries an independent state. Multiple parallel tasks should
-- not affect each other's state. This is especially important when we run
-- multiple tasks in a single thread.
snapshot :: (IsStream t, MonadAsync m, MonadState Int m) => t m ()
snapshot =
-- We deliberately use a replicate count 1 here, because a lower count
-- catches problems that a higher count doesn't.
S.replicateM 1 $ do
-- Even though we modify the state here it should not reflect in other
-- parallel tasks, it is local to each concurrent task.
modify (+1) >> get >>= liftIO . (`shouldSatisfy` (==1))
modify (+1) >> get >>= liftIO . (`shouldSatisfy` (==2))
snapshot1 :: (IsStream t, MonadAsync m, MonadState Int m) => t m ()
snapshot1 = S.replicateM 1000 $
modify (+1) >> get >>= liftIO . (`shouldSatisfy` (==2))
snapshot2 :: (IsStream t, MonadAsync m, MonadState Int m) => t m ()
snapshot2 = S.replicateM 1000 $
modify (+1) >> get >>= liftIO . (`shouldSatisfy` (==2))
stateComp
:: ( IsStream t
, MonadAsync m
, Semigroup (t m ())
, MonadIO (t m)
, MonadState Int m
, MonadState Int (t m)
)
=> t m ()
stateComp = do
-- Each task in a concurrent composition inherits the state and maintains
-- its own modifications to it, not affecting the parent computation.
snapshot <> (modify (+1) >> (snapshot1 <> snapshot2))
-- The above modify statement does not affect our state because that is
-- used in a parallel composition. In a serial composition it will affect
-- our state.
get >>= liftIO . (`shouldSatisfy` (== (0 :: Int)))
monadicStateSnapshot
:: ( IsStream t
, Semigroup (t (StateT Int IO) ())
, MonadIO (t (StateT Int IO))
, MonadState Int (t (StateT Int IO))
)
=> (t (StateT Int IO) () -> SerialT (StateT Int IO) ()) -> IO ()
monadicStateSnapshot t = void $ runStateT (runStream $ t stateComp) 0
stateCompOp
:: ( AsyncT (StateT Int IO) ()
-> AsyncT (StateT Int IO) ()
-> AsyncT (StateT Int IO) ()
)
-> SerialT (StateT Int IO) ()
stateCompOp op = do
-- Each task in a concurrent composition inherits the state and maintains
-- its own modifications to it, not affecting the parent computation.
asyncly (snapshot `op` (modify (+1) >> (snapshot1 `op` snapshot2)))
-- The above modify statement does not affect our state because that is
-- used in a parallel composition. In a serial composition it will affect
-- our state.
get >>= liftIO . (`shouldSatisfy` (== (0 :: Int)))
monadicStateSnapshotOp
:: ( AsyncT (StateT Int IO) ()
-> AsyncT (StateT Int IO) ()
-> AsyncT (StateT Int IO) ()
)
-> IO ()
monadicStateSnapshotOp op = void $ runStateT (runStream $ stateCompOp op) 0
takeCombined :: (Monad m, Semigroup (t m Int), Show a, Eq a, IsStream t)
=> Int -> (t m Int -> SerialT IO a) -> IO ()
takeCombined n t = do