Fix monadic state capture and restore for concurrent tasks

This causes up to 30% regression in async stream generation benchmarks and up
to 200% regression in async nested benchmarks. Mostly, due to an additional
functional call that cannot be inlined.
This commit is contained in:
Harendra Kumar 2018-09-16 14:14:37 +05:30
parent ba5a8c44b8
commit e810658dfe
7 changed files with 241 additions and 44 deletions

View File

@ -5,6 +5,12 @@
* Fixed a livelock in ahead style streams. The problem manifests sometimes when * Fixed a livelock in ahead style streams. The problem manifests sometimes when
multiple streams are merged together in ahead style and one of them is a nil multiple streams are merged together in ahead style and one of them is a nil
stream. stream.
* As per expected concurrency semantics each forked concurrent task must run
with the monadic state captured at the fork point. This release fixes a bug,
which, in some cases caused an incorrect monadic state to be used for a
concurrent action, leading to unexpected behavior when concurrent streams are
used in a stateful monad e.g. `StateT`. Particularly, this bug cannot affect
`ReaderT`.
## 0.5.1 ## 0.5.1

View File

@ -8,6 +8,7 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-} {-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE UnboxedTuples #-} {-# LANGUAGE UnboxedTuples #-}
@ -49,6 +50,8 @@ module Streamly.SVar
-- SVar related -- SVar related
, newAheadVar , newAheadVar
, newParallelVar , newParallelVar
, captureMonadState
, RunInIO (..)
, atomicModifyIORefCAS , atomicModifyIORefCAS
, WorkerInfo (..) , WorkerInfo (..)
@ -113,7 +116,7 @@ import Control.Exception
import Control.Monad (when) import Control.Monad (when)
import Control.Monad.Catch (MonadThrow) import Control.Monad.Catch (MonadThrow)
import Control.Monad.IO.Class (MonadIO(..)) import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Trans.Control (MonadBaseControl, control) import Control.Monad.Trans.Control (MonadBaseControl, control, StM)
import Data.Atomics import Data.Atomics
(casIORef, readForCAS, peekTicket, atomicModifyIORefCAS_, (casIORef, readForCAS, peekTicket, atomicModifyIORefCAS_,
writeBarrier, storeLoadBarrier) writeBarrier, storeLoadBarrier)
@ -325,7 +328,8 @@ data Limit = Unlimited | Limited Word deriving Show
data SVar t m a = SVar data SVar t m a = SVar
{ {
-- Read only state -- Read only state
svarStyle :: SVarStyle svarStyle :: SVarStyle
, svarMrun :: RunInIO m
-- Shared output queue (events, length) -- Shared output queue (events, length)
-- XXX For better efficiency we can try a preallocated array type (perhaps -- XXX For better efficiency we can try a preallocated array type (perhaps
@ -777,6 +781,18 @@ ringDoorBell sv = do
-- @since 0.1.0 -- @since 0.1.0
type MonadAsync m = (MonadIO m, MonadBaseControl IO m, MonadThrow m) type MonadAsync m = (MonadIO m, MonadBaseControl IO m, MonadThrow m)
-- When we run computations concurrently, we completely isolate the state of
-- the concurrent computations from the parent computation. The invariant is
-- that we should never be running two concurrent computations in the same
-- thread without using the runInIO function. Also, we should never be running
-- a concurrent computation in the parent thread, otherwise it may affect the
-- state of the parent which is against the defined semantics of concurrent
-- execution.
newtype RunInIO m = RunInIO { runInIO :: forall b. m b -> IO (StM m b) }
captureMonadState :: MonadBaseControl IO m => m (RunInIO m)
captureMonadState = control $ \run -> run (return $ RunInIO run)
-- Stolen from the async package. The perf improvement is modest, 2% on a -- Stolen from the async package. The perf improvement is modest, 2% on a
-- thread heavy benchmark (parallel composition using noop computations). -- thread heavy benchmark (parallel composition using noop computations).
-- A version of forkIO that does not include the outer exception -- A version of forkIO that does not include the outer exception
@ -790,14 +806,15 @@ rawForkIO action = IO $ \ s ->
{-# INLINE doFork #-} {-# INLINE doFork #-}
doFork :: MonadBaseControl IO m doFork :: MonadBaseControl IO m
=> m () => m ()
-> RunInIO m
-> (SomeException -> IO ()) -> (SomeException -> IO ())
-> m ThreadId -> m ThreadId
doFork action exHandler = doFork action (RunInIO mrun) exHandler =
control $ \runInIO -> control $ \run ->
mask $ \restore -> do mask $ \restore -> do
tid <- rawForkIO $ catch (restore $ void $ runInIO action) tid <- rawForkIO $ catch (restore $ void $ mrun action)
exHandler exHandler
runInIO (return tid) run (return tid)
-- XXX Can we make access to remainingWork and yieldRateInfo fields in sv -- XXX Can we make access to remainingWork and yieldRateInfo fields in sv
-- faster, along with the fields in sv required by send? -- faster, along with the fields in sv required by send?
@ -1288,7 +1305,8 @@ pushWorker yieldMax sv = do
, workerYieldCount = cntRef , workerYieldCount = cntRef
, workerLatencyStart = lat , workerLatencyStart = lat
} }
doFork (workLoop sv winfo) (handleChildException sv) >>= addThread sv doFork (workLoop sv winfo) (svarMrun sv) (handleChildException sv)
>>= addThread sv
-- XXX we can push the workerCount modification in accountThread and use the -- XXX we can push the workerCount modification in accountThread and use the
-- same pushWorker for Parallel case as well. -- same pushWorker for Parallel case as well.
@ -1305,7 +1323,8 @@ pushWorkerPar
pushWorkerPar sv wloop = pushWorkerPar sv wloop =
if svarInspectMode sv if svarInspectMode sv
then forkWithDiag then forkWithDiag
else doFork (wloop Nothing) (handleChildException sv) >>= modifyThread sv else doFork (wloop Nothing) (svarMrun sv) (handleChildException sv)
>>= modifyThread sv
where where
@ -1330,7 +1349,8 @@ pushWorkerPar sv wloop =
, workerLatencyStart = lat , workerLatencyStart = lat
} }
doFork (wloop winfo) (handleChildException sv) >>= modifyThread sv doFork (wloop winfo) (svarMrun sv) (handleChildException sv)
>>= modifyThread sv
-- Returns: -- Returns:
-- True: can dispatch more -- True: can dispatch more
@ -1997,8 +2017,9 @@ getAheadSVar :: MonadAsync m
-> SVar t m a -> SVar t m a
-> Maybe WorkerInfo -> Maybe WorkerInfo
-> m ()) -> m ())
-> RunInIO m
-> IO (SVar t m a) -> IO (SVar t m a)
getAheadSVar st f = do getAheadSVar st f mrun = do
outQ <- newIORef ([], 0) outQ <- newIORef ([], 0)
-- the second component of the tuple is "Nothing" when heap is being -- the second component of the tuple is "Nothing" when heap is being
-- cleared, "Just n" when we are expecting sequence number n to arrive -- cleared, "Just n" when we are expecting sequence number n to arrive
@ -2036,6 +2057,7 @@ getAheadSVar st f = do
, isQueueDone = isQueueDoneAhead sv q , isQueueDone = isQueueDoneAhead sv q
, needDoorBell = wfw , needDoorBell = wfw
, svarStyle = AheadVar , svarStyle = AheadVar
, svarMrun = mrun
, workerCount = active , workerCount = active
, accountThread = delThread sv , accountThread = delThread sv
, workerStopMVar = stopMVar , workerStopMVar = stopMVar
@ -2081,8 +2103,8 @@ getAheadSVar st f = do
(xs, _) <- readIORef q (xs, _) <- readIORef q
return $ null xs return $ null xs
getParallelSVar :: MonadIO m => State t m a -> IO (SVar t m a) getParallelSVar :: MonadIO m => State t m a -> RunInIO m -> IO (SVar t m a)
getParallelSVar st = do getParallelSVar st mrun = do
outQ <- newIORef ([], 0) outQ <- newIORef ([], 0)
outQMv <- newEmptyMVar outQMv <- newEmptyMVar
active <- newIORef 0 active <- newIORef 0
@ -2112,6 +2134,7 @@ getParallelSVar st = do
, isQueueDone = undefined , isQueueDone = undefined
, needDoorBell = undefined , needDoorBell = undefined
, svarStyle = ParallelVar , svarStyle = ParallelVar
, svarMrun = mrun
, workerCount = active , workerCount = active
, accountThread = modifyThread sv , accountThread = modifyThread sv
, workerStopMVar = undefined , workerStopMVar = undefined
@ -2160,12 +2183,15 @@ newAheadVar :: MonadAsync m
-> m ()) -> m ())
-> m (SVar t m a) -> m (SVar t m a)
newAheadVar st m wloop = do newAheadVar st m wloop = do
sv <- liftIO $ getAheadSVar st wloop mrun <- captureMonadState
sv <- liftIO $ getAheadSVar st wloop mrun
sendFirstWorker sv m sendFirstWorker sv m
{-# INLINABLE newParallelVar #-} {-# INLINABLE newParallelVar #-}
newParallelVar :: MonadAsync m => State t m a -> m (SVar t m a) newParallelVar :: MonadAsync m => State t m a -> m (SVar t m a)
newParallelVar st = liftIO $ getParallelSVar st newParallelVar st = do
mrun <- captureMonadState
liftIO $ getParallelSVar st mrun
-- XXX this errors out for Parallel/Ahead SVars -- XXX this errors out for Parallel/Ahead SVars
-- | Write a stream to an 'SVar' in a non-blocking manner. The stream can then -- | Write a stream to an 'SVar' in a non-blocking manner. The stream can then

View File

@ -47,7 +47,9 @@ import qualified Data.Heap as H
import Streamly.Streams.SVar (fromSVar) import Streamly.Streams.SVar (fromSVar)
import Streamly.Streams.Serial (map) import Streamly.Streams.Serial (map)
import Streamly.SVar import Streamly.SVar
import Streamly.Streams.StreamK (IsStream(..), Stream(..)) import Streamly.Streams.StreamK
(IsStream(..), Stream(..), unstreamShared, unStreamIsolated,
runStreamSVar)
import qualified Streamly.Streams.StreamK as K import qualified Streamly.Streams.StreamK as K
import Prelude hiding (map) import Prelude hiding (map)
@ -296,7 +298,7 @@ processHeap q heap st sv winfo entry sno stopping = loopHeap sno entry
let stop = do let stop = do
liftIO (incrementYieldLimit sv) liftIO (incrementYieldLimit sv)
nextHeap seqNo nextHeap seqNo
unStream r st stop runStreamSVar sv r st stop
(singleStreamFromHeap seqNo) (singleStreamFromHeap seqNo)
(yieldStreamFromHeap seqNo) (yieldStreamFromHeap seqNo)
else liftIO $ do else liftIO $ do
@ -346,7 +348,7 @@ processWithoutToken q heap st sv winfo m seqNo = do
-- we stop. -- we stop.
toHeap AheadEntryNull toHeap AheadEntryNull
unStream m st stop runStreamSVar sv m st stop
(toHeap . AheadEntryPure) (toHeap . AheadEntryPure)
(\a r -> toHeap $ AheadEntryStream $ K.cons a r) (\a r -> toHeap $ AheadEntryStream $ K.cons a r)
@ -406,7 +408,7 @@ processWithToken q heap st sv winfo action sno = do
liftIO (incrementYieldLimit sv) liftIO (incrementYieldLimit sv)
loopWithToken (sno + 1) loopWithToken (sno + 1)
unStream action st stop (singleOutput sno) (yieldOutput sno) runStreamSVar sv action st stop (singleOutput sno) (yieldOutput sno)
where where
@ -429,7 +431,7 @@ processWithToken q heap st sv winfo action sno = do
let stop = do let stop = do
liftIO (incrementYieldLimit sv) liftIO (incrementYieldLimit sv)
loopWithToken (seqNo + 1) loopWithToken (seqNo + 1)
unStream r st stop runStreamSVar sv r st stop
(singleOutput seqNo) (singleOutput seqNo)
(yieldOutput seqNo) (yieldOutput seqNo)
else do else do
@ -458,7 +460,7 @@ processWithToken q heap st sv winfo action sno = do
let stop = do let stop = do
liftIO (incrementYieldLimit sv) liftIO (incrementYieldLimit sv)
loopWithToken (seqNo + 1) loopWithToken (seqNo + 1)
unStream m st stop runStreamSVar sv m st stop
(singleOutput seqNo) (singleOutput seqNo)
(yieldOutput seqNo) (yieldOutput seqNo)
else else
@ -681,10 +683,13 @@ aheadbind m f = go m
where where
go (Stream g) = go (Stream g) =
Stream $ \st stp sng yld -> Stream $ \st stp sng yld ->
let run x = unStream x st stp sng yld let runShared x = unstreamShared x st stp sng yld
single a = run $ f a runIsolated x = unStreamIsolated x st stp sng yld
yieldk a r = run $ f a `aheadS` go r
in g (rstState st) stp single yieldk single a = runIsolated $ f a
yieldk a r = runShared $
K.isolateStream (f a) `aheadS` go r
in g (rstState st) stp single yieldk
instance MonadAsync m => Monad (AheadT m) where instance MonadAsync m => Monad (AheadT m) where
return = pure return = pure

View File

@ -57,7 +57,7 @@ import qualified Data.Set as S
import Streamly.Streams.SVar (fromSVar) import Streamly.Streams.SVar (fromSVar)
import Streamly.Streams.Serial (map) import Streamly.Streams.Serial (map)
import Streamly.SVar import Streamly.SVar
import Streamly.Streams.StreamK (IsStream(..), Stream(..), adapt) import Streamly.Streams.StreamK (IsStream(..), Stream(..), adapt, runStreamSVar)
import qualified Streamly.Streams.StreamK as K import qualified Streamly.Streams.StreamK as K
#include "Instances.hs" #include "Instances.hs"
@ -82,7 +82,7 @@ workLoopLIFO q st sv winfo = run
work <- dequeue work <- dequeue
case work of case work of
Nothing -> liftIO $ sendStop sv winfo Nothing -> liftIO $ sendStop sv winfo
Just m -> unStream m st run single yieldk Just m -> runStreamSVar sv m st run single yieldk
single a = do single a = do
res <- liftIO $ sendYield sv winfo (ChildYield a) res <- liftIO $ sendYield sv winfo (ChildYield a)
@ -91,7 +91,7 @@ workLoopLIFO q st sv winfo = run
yieldk a r = do yieldk a r = do
res <- liftIO $ sendYield sv winfo (ChildYield a) res <- liftIO $ sendYield sv winfo (ChildYield a)
if res if res
then unStream r st run single yieldk then runStreamSVar sv r st run single yieldk
else liftIO $ do else liftIO $ do
enqueueLIFO sv q r enqueueLIFO sv q r
sendStop sv winfo sendStop sv winfo
@ -132,7 +132,7 @@ workLoopLIFOLimited q st sv winfo = run
if yieldLimitOk if yieldLimitOk
then do then do
let stop = liftIO (incrementYieldLimit sv) >> run let stop = liftIO (incrementYieldLimit sv) >> run
unStream m st stop single yieldk runStreamSVar sv m st stop single yieldk
-- Avoid any side effects, undo the yield limit decrement if we -- Avoid any side effects, undo the yield limit decrement if we
-- never yielded anything. -- never yielded anything.
else liftIO $ do else liftIO $ do
@ -151,7 +151,7 @@ workLoopLIFOLimited q st sv winfo = run
yieldLimitOk <- liftIO $ decrementYieldLimit sv yieldLimitOk <- liftIO $ decrementYieldLimit sv
let stop = liftIO (incrementYieldLimit sv) >> run let stop = liftIO (incrementYieldLimit sv) >> run
if res && yieldLimitOk if res && yieldLimitOk
then unStream r st stop single yieldk then runStreamSVar sv r st stop single yieldk
else liftIO $ do else liftIO $ do
incrementYieldLimit sv incrementYieldLimit sv
enqueueLIFO sv q r enqueueLIFO sv q r
@ -183,7 +183,7 @@ workLoopFIFO q st sv winfo = run
work <- liftIO $ tryPopR q work <- liftIO $ tryPopR q
case work of case work of
Nothing -> liftIO $ sendStop sv winfo Nothing -> liftIO $ sendStop sv winfo
Just m -> unStream m st run single yieldk Just m -> runStreamSVar sv m st run single yieldk
single a = do single a = do
res <- liftIO $ sendYield sv winfo (ChildYield a) res <- liftIO $ sendYield sv winfo (ChildYield a)
@ -192,7 +192,7 @@ workLoopFIFO q st sv winfo = run
yieldk a r = do yieldk a r = do
res <- liftIO $ sendYield sv winfo (ChildYield a) res <- liftIO $ sendYield sv winfo (ChildYield a)
if res if res
then unStream r st run single yieldk then runStreamSVar sv r st run single yieldk
else liftIO $ do else liftIO $ do
enqueueFIFO sv q r enqueueFIFO sv q r
sendStop sv winfo sendStop sv winfo
@ -218,7 +218,7 @@ workLoopFIFOLimited q st sv winfo = run
if yieldLimitOk if yieldLimitOk
then do then do
let stop = liftIO (incrementYieldLimit sv) >> run let stop = liftIO (incrementYieldLimit sv) >> run
unStream m st stop single yieldk runStreamSVar sv m st stop single yieldk
else liftIO $ do else liftIO $ do
enqueueFIFO sv q m enqueueFIFO sv q m
incrementYieldLimit sv incrementYieldLimit sv
@ -233,7 +233,7 @@ workLoopFIFOLimited q st sv winfo = run
yieldLimitOk <- liftIO $ decrementYieldLimit sv yieldLimitOk <- liftIO $ decrementYieldLimit sv
let stop = liftIO (incrementYieldLimit sv) >> run let stop = liftIO (incrementYieldLimit sv) >> run
if res && yieldLimitOk if res && yieldLimitOk
then unStream r st stop single yieldk then runStreamSVar sv r st stop single yieldk
else liftIO $ do else liftIO $ do
incrementYieldLimit sv incrementYieldLimit sv
enqueueFIFO sv q r enqueueFIFO sv q r
@ -249,8 +249,8 @@ workLoopFIFOLimited q st sv winfo = run
-- than 10%. Need to investigate what the root cause is. -- than 10%. Need to investigate what the root cause is.
-- Interestingly, the same thing does not make any difference for Ahead. -- Interestingly, the same thing does not make any difference for Ahead.
getLifoSVar :: forall m a. MonadAsync m getLifoSVar :: forall m a. MonadAsync m
=> State Stream m a -> IO (SVar Stream m a) => State Stream m a -> RunInIO m -> IO (SVar Stream m a)
getLifoSVar st = do getLifoSVar st mrun = do
outQ <- newIORef ([], 0) outQ <- newIORef ([], 0)
outQMv <- newEmptyMVar outQMv <- newEmptyMVar
active <- newIORef 0 active <- newIORef 0
@ -303,6 +303,7 @@ getLifoSVar st = do
, isQueueDone = workDone sv , isQueueDone = workDone sv
, needDoorBell = wfw , needDoorBell = wfw
, svarStyle = AsyncVar , svarStyle = AsyncVar
, svarMrun = mrun
, workerCount = active , workerCount = active
, accountThread = delThread sv , accountThread = delThread sv
, workerStopMVar = undefined , workerStopMVar = undefined
@ -339,8 +340,8 @@ getLifoSVar st = do
in return sv in return sv
getFifoSVar :: forall m a. MonadAsync m getFifoSVar :: forall m a. MonadAsync m
=> State Stream m a -> IO (SVar Stream m a) => State Stream m a -> RunInIO m -> IO (SVar Stream m a)
getFifoSVar st = do getFifoSVar st mrun = do
outQ <- newIORef ([], 0) outQ <- newIORef ([], 0)
outQMv <- newEmptyMVar outQMv <- newEmptyMVar
active <- newIORef 0 active <- newIORef 0
@ -392,6 +393,7 @@ getFifoSVar st = do
, isQueueDone = workDone sv , isQueueDone = workDone sv
, needDoorBell = wfw , needDoorBell = wfw
, svarStyle = WAsyncVar , svarStyle = WAsyncVar
, svarMrun = mrun
, workerCount = active , workerCount = active
, accountThread = delThread sv , accountThread = delThread sv
, workerStopMVar = undefined , workerStopMVar = undefined
@ -431,7 +433,8 @@ getFifoSVar st = do
newAsyncVar :: MonadAsync m newAsyncVar :: MonadAsync m
=> State Stream m a -> Stream m a -> m (SVar Stream m a) => State Stream m a -> Stream m a -> m (SVar Stream m a)
newAsyncVar st m = do newAsyncVar st m = do
sv <- liftIO $ getLifoSVar st mrun <- captureMonadState
sv <- liftIO $ getLifoSVar st mrun
sendFirstWorker sv m sendFirstWorker sv m
-- XXX Get rid of this? -- XXX Get rid of this?
@ -455,7 +458,8 @@ mkAsync' st m = fmap fromSVar (newAsyncVar st (toStream m))
newWAsyncVar :: MonadAsync m newWAsyncVar :: MonadAsync m
=> State Stream m a -> Stream m a -> m (SVar Stream m a) => State Stream m a -> Stream m a -> m (SVar Stream m a)
newWAsyncVar st m = do newWAsyncVar st m = do
sv <- liftIO $ getFifoSVar st mrun <- captureMonadState
sv <- liftIO $ getFifoSVar st mrun
sendFirstWorker sv m sendFirstWorker sv m
------------------------------------------------------------------------------ ------------------------------------------------------------------------------

View File

@ -67,6 +67,7 @@ runOne st m winfo = unStream m st stop single yieldk
where where
sv = fromJust $ streamVar st sv = fromJust $ streamVar st
mrun = runInIO $ svarMrun sv
withLimitCheck action = do withLimitCheck action = do
yieldLimitOk <- liftIO $ decrementYieldLimitPost sv yieldLimitOk <- liftIO $ decrementYieldLimitPost sv
@ -82,7 +83,8 @@ runOne st m winfo = unStream m st stop single yieldk
-- queue and queue it back on that and exit the thread when the outputQueue -- queue and queue it back on that and exit the thread when the outputQueue
-- overflows. Parallel is dangerous because it can accumulate unbounded -- overflows. Parallel is dangerous because it can accumulate unbounded
-- output in the buffer. -- output in the buffer.
yieldk a r = void (sendit a) >> withLimitCheck (runOne st r winfo) yieldk a r = void (sendit a)
>> withLimitCheck (void $ liftIO $ mrun $ runOne st r winfo)
{-# NOINLINE forkSVarPar #-} {-# NOINLINE forkSVarPar #-}
forkSVarPar :: MonadAsync m => Stream m a -> Stream m a -> Stream m a forkSVarPar :: MonadAsync m => Stream m a -> Stream m a -> Stream m a

View File

@ -33,6 +33,10 @@ module Streamly.Streams.StreamK
-- * The stream type -- * The stream type
, Stream (..) , Stream (..)
, unStreamIsolated
, isolateStream
, unstreamShared
, runStreamSVar
-- * Construction -- * Construction
, mkStream , mkStream
@ -140,6 +144,7 @@ module Streamly.Streams.StreamK
where where
import Control.Monad (void) import Control.Monad (void)
import Control.Monad.IO.Class (MonadIO(liftIO))
import Control.Monad.Reader.Class (MonadReader(..)) import Control.Monad.Reader.Class (MonadReader(..))
import Control.Monad.Trans.Class (MonadTrans(lift)) import Control.Monad.Trans.Class (MonadTrans(lift))
import Data.Semigroup (Semigroup(..)) import Data.Semigroup (Semigroup(..))
@ -183,6 +188,51 @@ newtype Stream m a =
-> m r -> m r
} }
-- XXX make this the default "unStream"
-- | unwraps the Stream type producing the stream function that can be run with
-- continuations.
{-# INLINE unStreamIsolated #-}
unStreamIsolated ::
Stream m a
-> State Stream m a -- state
-> m r -- stop
-> (a -> m r) -- singleton
-> (a -> Stream m a -> m r) -- yield
-> m r
unStreamIsolated x st = unStream x (rstState st)
{-# INLINE isolateStream #-}
isolateStream :: Stream m a -> Stream m a
isolateStream x = Stream $ \st stp sng yld ->
unStreamIsolated x st stp sng yld
-- | Like unstream, but passes a shared SVar across continuations.
{-# INLINE unstreamShared #-}
unstreamShared ::
Stream m a
-> State Stream m a -- state
-> m r -- stop
-> (a -> m r) -- singleton
-> (a -> Stream m a -> m r) -- yield
-> m r
unstreamShared = unStream
-- Run the stream using a run function associated with the SVar that runs the
-- streams with a captured snapshot of the monadic state.
{-# INLINE runStreamSVar #-}
runStreamSVar
:: MonadIO m
=> SVar Stream m a
-> Stream m a
-> State Stream m a -- state
-> m r -- stop
-> (a -> m r) -- singleton
-> (a -> Stream m a -> m r) -- yield
-> m ()
runStreamSVar sv m st stp sng yld =
let mrun = runInIO $ svarMrun sv
in void $ liftIO $ mrun $ unStream m st stp sng yld
------------------------------------------------------------------------------ ------------------------------------------------------------------------------
-- Types that can behave as a Stream -- Types that can behave as a Stream
------------------------------------------------------------------------------ ------------------------------------------------------------------------------
@ -969,10 +1019,12 @@ bindWith par m f = go m
where where
go (Stream g) = go (Stream g) =
Stream $ \st stp sng yld -> Stream $ \st stp sng yld ->
let run x = unStream x st stp sng yld let runShared x = unstreamShared x st stp sng yld
single a = run $ f a runIsolated x = unStreamIsolated x st stp sng yld
yieldk a r = run $ f a `par` go r
in g (rstState st) stp single yieldk single a = runIsolated $ f a
yieldk a r = runShared $ isolateStream (f a) `par` go r
in g (rstState st) stp single yieldk
------------------------------------------------------------------------------ ------------------------------------------------------------------------------
-- Alternative & MonadPlus -- Alternative & MonadPlus

View File

@ -7,8 +7,11 @@ module Main (main) where
import Control.Concurrent (threadDelay) import Control.Concurrent (threadDelay)
import Control.Exception (Exception, try, ErrorCall(..), catch, throw) import Control.Exception (Exception, try, ErrorCall(..), catch, throw)
import Control.Monad (void)
import Control.Monad.Catch (throwM, MonadThrow) import Control.Monad.Catch (throwM, MonadThrow)
import Control.Monad.Error.Class (throwError, MonadError) import Control.Monad.Error.Class (throwError, MonadError)
import Control.Monad.IO.Class (MonadIO(liftIO))
import Control.Monad.State (MonadState, get, modify, runStateT, StateT)
import Control.Monad.Trans.Except (runExceptT, ExceptT) import Control.Monad.Trans.Except (runExceptT, ExceptT)
import Data.Foldable (forM_, fold) import Data.Foldable (forM_, fold)
import Data.List (sort) import Data.List (sort)
@ -495,6 +498,34 @@ parallelTests = H.parallel $ do
it "foldlM' is strict enough" (checkFoldMStrictness foldlM'StrictCheck) it "foldlM' is strict enough" (checkFoldMStrictness foldlM'StrictCheck)
it "scanlM' is strict enough" (checkScanlMStrictness scanlM'StrictCheck) it "scanlM' is strict enough" (checkScanlMStrictness scanlM'StrictCheck)
---------------------------------------------------------------------------
-- Monadic state snapshot in concurrent tasks
---------------------------------------------------------------------------
it "asyncly maintains independent states in concurrent tasks"
(monadicStateSnapshot asyncly)
it "asyncly limited maintains independent states in concurrent tasks"
(monadicStateSnapshot (asyncly . S.take 10000))
it "wAsyncly maintains independent states in concurrent tasks"
(monadicStateSnapshot wAsyncly)
it "wAsyncly limited maintains independent states in concurrent tasks"
(monadicStateSnapshot (wAsyncly . S.take 10000))
it "aheadly maintains independent states in concurrent tasks"
(monadicStateSnapshot aheadly)
it "aheadly limited maintains independent states in concurrent tasks"
(monadicStateSnapshot (aheadly . S.take 10000))
it "parallely maintains independent states in concurrent tasks"
(monadicStateSnapshot parallely)
it "async maintains independent states in concurrent tasks"
(monadicStateSnapshotOp async)
it "ahead maintains independent states in concurrent tasks"
(monadicStateSnapshotOp ahead)
it "wAsync maintains independent states in concurrent tasks"
(monadicStateSnapshotOp wAsync)
it "parallel maintains independent states in concurrent tasks"
(monadicStateSnapshotOp Streamly.parallel)
--------------------------------------------------------------------------- ---------------------------------------------------------------------------
-- Slower tests are at the end -- Slower tests are at the end
--------------------------------------------------------------------------- ---------------------------------------------------------------------------
@ -513,6 +544,77 @@ parallelTests = H.parallel $ do
replicate 4000 $ S.yieldM $ threadDelay 1000000) replicate 4000 $ S.yieldM $ threadDelay 1000000)
`shouldReturn` () `shouldReturn` ()
-- Each snapshot carries an independent state. Multiple parallel tasks should
-- not affect each other's state. This is especially important when we run
-- multiple tasks in a single thread.
snapshot :: (IsStream t, MonadAsync m, MonadState Int m) => t m ()
snapshot =
-- We deliberately use a replicate count 1 here, because a lower count
-- catches problems that a higher count doesn't.
S.replicateM 1 $ do
-- Even though we modify the state here it should not reflect in other
-- parallel tasks, it is local to each concurrent task.
modify (+1) >> get >>= liftIO . (`shouldSatisfy` (==1))
modify (+1) >> get >>= liftIO . (`shouldSatisfy` (==2))
snapshot1 :: (IsStream t, MonadAsync m, MonadState Int m) => t m ()
snapshot1 = S.replicateM 1000 $
modify (+1) >> get >>= liftIO . (`shouldSatisfy` (==2))
snapshot2 :: (IsStream t, MonadAsync m, MonadState Int m) => t m ()
snapshot2 = S.replicateM 1000 $
modify (+1) >> get >>= liftIO . (`shouldSatisfy` (==2))
stateComp
:: ( IsStream t
, MonadAsync m
, Semigroup (t m ())
, MonadIO (t m)
, MonadState Int m
, MonadState Int (t m)
)
=> t m ()
stateComp = do
-- Each task in a concurrent composition inherits the state and maintains
-- its own modifications to it, not affecting the parent computation.
snapshot <> (modify (+1) >> (snapshot1 <> snapshot2))
-- The above modify statement does not affect our state because that is
-- used in a parallel composition. In a serial composition it will affect
-- our state.
get >>= liftIO . (`shouldSatisfy` (== (0 :: Int)))
monadicStateSnapshot
:: ( IsStream t
, Semigroup (t (StateT Int IO) ())
, MonadIO (t (StateT Int IO))
, MonadState Int (t (StateT Int IO))
)
=> (t (StateT Int IO) () -> SerialT (StateT Int IO) ()) -> IO ()
monadicStateSnapshot t = void $ runStateT (runStream $ t stateComp) 0
stateCompOp
:: ( AsyncT (StateT Int IO) ()
-> AsyncT (StateT Int IO) ()
-> AsyncT (StateT Int IO) ()
)
-> SerialT (StateT Int IO) ()
stateCompOp op = do
-- Each task in a concurrent composition inherits the state and maintains
-- its own modifications to it, not affecting the parent computation.
asyncly (snapshot `op` (modify (+1) >> (snapshot1 `op` snapshot2)))
-- The above modify statement does not affect our state because that is
-- used in a parallel composition. In a serial composition it will affect
-- our state.
get >>= liftIO . (`shouldSatisfy` (== (0 :: Int)))
monadicStateSnapshotOp
:: ( AsyncT (StateT Int IO) ()
-> AsyncT (StateT Int IO) ()
-> AsyncT (StateT Int IO) ()
)
-> IO ()
monadicStateSnapshotOp op = void $ runStateT (runStream $ stateCompOp op) 0
takeCombined :: (Monad m, Semigroup (t m Int), Show a, Eq a, IsStream t) takeCombined :: (Monad m, Semigroup (t m Int), Show a, Eq a, IsStream t)
=> Int -> (t m Int -> SerialT IO a) -> IO () => Int -> (t m Int -> SerialT IO a) -> IO ()
takeCombined n t = do takeCombined n t = do