mirror of
https://github.com/composewell/streamly.git
synced 2024-09-20 07:58:27 +03:00
Fix monadic state capture and restore for concurrent tasks
This causes up to 30% regression in async stream generation benchmarks and up to 200% regression in async nested benchmarks. Mostly, due to an additional functional call that cannot be inlined.
This commit is contained in:
parent
ba5a8c44b8
commit
e810658dfe
@ -5,6 +5,12 @@
|
|||||||
* Fixed a livelock in ahead style streams. The problem manifests sometimes when
|
* Fixed a livelock in ahead style streams. The problem manifests sometimes when
|
||||||
multiple streams are merged together in ahead style and one of them is a nil
|
multiple streams are merged together in ahead style and one of them is a nil
|
||||||
stream.
|
stream.
|
||||||
|
* As per expected concurrency semantics each forked concurrent task must run
|
||||||
|
with the monadic state captured at the fork point. This release fixes a bug,
|
||||||
|
which, in some cases caused an incorrect monadic state to be used for a
|
||||||
|
concurrent action, leading to unexpected behavior when concurrent streams are
|
||||||
|
used in a stateful monad e.g. `StateT`. Particularly, this bug cannot affect
|
||||||
|
`ReaderT`.
|
||||||
|
|
||||||
## 0.5.1
|
## 0.5.1
|
||||||
|
|
||||||
|
@ -8,6 +8,7 @@
|
|||||||
{-# LANGUAGE LambdaCase #-}
|
{-# LANGUAGE LambdaCase #-}
|
||||||
{-# LANGUAGE MagicHash #-}
|
{-# LANGUAGE MagicHash #-}
|
||||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||||
|
{-# LANGUAGE RankNTypes #-}
|
||||||
{-# LANGUAGE ScopedTypeVariables #-}
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
{-# LANGUAGE UnboxedTuples #-}
|
{-# LANGUAGE UnboxedTuples #-}
|
||||||
|
|
||||||
@ -49,6 +50,8 @@ module Streamly.SVar
|
|||||||
-- SVar related
|
-- SVar related
|
||||||
, newAheadVar
|
, newAheadVar
|
||||||
, newParallelVar
|
, newParallelVar
|
||||||
|
, captureMonadState
|
||||||
|
, RunInIO (..)
|
||||||
|
|
||||||
, atomicModifyIORefCAS
|
, atomicModifyIORefCAS
|
||||||
, WorkerInfo (..)
|
, WorkerInfo (..)
|
||||||
@ -113,7 +116,7 @@ import Control.Exception
|
|||||||
import Control.Monad (when)
|
import Control.Monad (when)
|
||||||
import Control.Monad.Catch (MonadThrow)
|
import Control.Monad.Catch (MonadThrow)
|
||||||
import Control.Monad.IO.Class (MonadIO(..))
|
import Control.Monad.IO.Class (MonadIO(..))
|
||||||
import Control.Monad.Trans.Control (MonadBaseControl, control)
|
import Control.Monad.Trans.Control (MonadBaseControl, control, StM)
|
||||||
import Data.Atomics
|
import Data.Atomics
|
||||||
(casIORef, readForCAS, peekTicket, atomicModifyIORefCAS_,
|
(casIORef, readForCAS, peekTicket, atomicModifyIORefCAS_,
|
||||||
writeBarrier, storeLoadBarrier)
|
writeBarrier, storeLoadBarrier)
|
||||||
@ -325,7 +328,8 @@ data Limit = Unlimited | Limited Word deriving Show
|
|||||||
data SVar t m a = SVar
|
data SVar t m a = SVar
|
||||||
{
|
{
|
||||||
-- Read only state
|
-- Read only state
|
||||||
svarStyle :: SVarStyle
|
svarStyle :: SVarStyle
|
||||||
|
, svarMrun :: RunInIO m
|
||||||
|
|
||||||
-- Shared output queue (events, length)
|
-- Shared output queue (events, length)
|
||||||
-- XXX For better efficiency we can try a preallocated array type (perhaps
|
-- XXX For better efficiency we can try a preallocated array type (perhaps
|
||||||
@ -777,6 +781,18 @@ ringDoorBell sv = do
|
|||||||
-- @since 0.1.0
|
-- @since 0.1.0
|
||||||
type MonadAsync m = (MonadIO m, MonadBaseControl IO m, MonadThrow m)
|
type MonadAsync m = (MonadIO m, MonadBaseControl IO m, MonadThrow m)
|
||||||
|
|
||||||
|
-- When we run computations concurrently, we completely isolate the state of
|
||||||
|
-- the concurrent computations from the parent computation. The invariant is
|
||||||
|
-- that we should never be running two concurrent computations in the same
|
||||||
|
-- thread without using the runInIO function. Also, we should never be running
|
||||||
|
-- a concurrent computation in the parent thread, otherwise it may affect the
|
||||||
|
-- state of the parent which is against the defined semantics of concurrent
|
||||||
|
-- execution.
|
||||||
|
newtype RunInIO m = RunInIO { runInIO :: forall b. m b -> IO (StM m b) }
|
||||||
|
|
||||||
|
captureMonadState :: MonadBaseControl IO m => m (RunInIO m)
|
||||||
|
captureMonadState = control $ \run -> run (return $ RunInIO run)
|
||||||
|
|
||||||
-- Stolen from the async package. The perf improvement is modest, 2% on a
|
-- Stolen from the async package. The perf improvement is modest, 2% on a
|
||||||
-- thread heavy benchmark (parallel composition using noop computations).
|
-- thread heavy benchmark (parallel composition using noop computations).
|
||||||
-- A version of forkIO that does not include the outer exception
|
-- A version of forkIO that does not include the outer exception
|
||||||
@ -790,14 +806,15 @@ rawForkIO action = IO $ \ s ->
|
|||||||
{-# INLINE doFork #-}
|
{-# INLINE doFork #-}
|
||||||
doFork :: MonadBaseControl IO m
|
doFork :: MonadBaseControl IO m
|
||||||
=> m ()
|
=> m ()
|
||||||
|
-> RunInIO m
|
||||||
-> (SomeException -> IO ())
|
-> (SomeException -> IO ())
|
||||||
-> m ThreadId
|
-> m ThreadId
|
||||||
doFork action exHandler =
|
doFork action (RunInIO mrun) exHandler =
|
||||||
control $ \runInIO ->
|
control $ \run ->
|
||||||
mask $ \restore -> do
|
mask $ \restore -> do
|
||||||
tid <- rawForkIO $ catch (restore $ void $ runInIO action)
|
tid <- rawForkIO $ catch (restore $ void $ mrun action)
|
||||||
exHandler
|
exHandler
|
||||||
runInIO (return tid)
|
run (return tid)
|
||||||
|
|
||||||
-- XXX Can we make access to remainingWork and yieldRateInfo fields in sv
|
-- XXX Can we make access to remainingWork and yieldRateInfo fields in sv
|
||||||
-- faster, along with the fields in sv required by send?
|
-- faster, along with the fields in sv required by send?
|
||||||
@ -1288,7 +1305,8 @@ pushWorker yieldMax sv = do
|
|||||||
, workerYieldCount = cntRef
|
, workerYieldCount = cntRef
|
||||||
, workerLatencyStart = lat
|
, workerLatencyStart = lat
|
||||||
}
|
}
|
||||||
doFork (workLoop sv winfo) (handleChildException sv) >>= addThread sv
|
doFork (workLoop sv winfo) (svarMrun sv) (handleChildException sv)
|
||||||
|
>>= addThread sv
|
||||||
|
|
||||||
-- XXX we can push the workerCount modification in accountThread and use the
|
-- XXX we can push the workerCount modification in accountThread and use the
|
||||||
-- same pushWorker for Parallel case as well.
|
-- same pushWorker for Parallel case as well.
|
||||||
@ -1305,7 +1323,8 @@ pushWorkerPar
|
|||||||
pushWorkerPar sv wloop =
|
pushWorkerPar sv wloop =
|
||||||
if svarInspectMode sv
|
if svarInspectMode sv
|
||||||
then forkWithDiag
|
then forkWithDiag
|
||||||
else doFork (wloop Nothing) (handleChildException sv) >>= modifyThread sv
|
else doFork (wloop Nothing) (svarMrun sv) (handleChildException sv)
|
||||||
|
>>= modifyThread sv
|
||||||
|
|
||||||
where
|
where
|
||||||
|
|
||||||
@ -1330,7 +1349,8 @@ pushWorkerPar sv wloop =
|
|||||||
, workerLatencyStart = lat
|
, workerLatencyStart = lat
|
||||||
}
|
}
|
||||||
|
|
||||||
doFork (wloop winfo) (handleChildException sv) >>= modifyThread sv
|
doFork (wloop winfo) (svarMrun sv) (handleChildException sv)
|
||||||
|
>>= modifyThread sv
|
||||||
|
|
||||||
-- Returns:
|
-- Returns:
|
||||||
-- True: can dispatch more
|
-- True: can dispatch more
|
||||||
@ -1997,8 +2017,9 @@ getAheadSVar :: MonadAsync m
|
|||||||
-> SVar t m a
|
-> SVar t m a
|
||||||
-> Maybe WorkerInfo
|
-> Maybe WorkerInfo
|
||||||
-> m ())
|
-> m ())
|
||||||
|
-> RunInIO m
|
||||||
-> IO (SVar t m a)
|
-> IO (SVar t m a)
|
||||||
getAheadSVar st f = do
|
getAheadSVar st f mrun = do
|
||||||
outQ <- newIORef ([], 0)
|
outQ <- newIORef ([], 0)
|
||||||
-- the second component of the tuple is "Nothing" when heap is being
|
-- the second component of the tuple is "Nothing" when heap is being
|
||||||
-- cleared, "Just n" when we are expecting sequence number n to arrive
|
-- cleared, "Just n" when we are expecting sequence number n to arrive
|
||||||
@ -2036,6 +2057,7 @@ getAheadSVar st f = do
|
|||||||
, isQueueDone = isQueueDoneAhead sv q
|
, isQueueDone = isQueueDoneAhead sv q
|
||||||
, needDoorBell = wfw
|
, needDoorBell = wfw
|
||||||
, svarStyle = AheadVar
|
, svarStyle = AheadVar
|
||||||
|
, svarMrun = mrun
|
||||||
, workerCount = active
|
, workerCount = active
|
||||||
, accountThread = delThread sv
|
, accountThread = delThread sv
|
||||||
, workerStopMVar = stopMVar
|
, workerStopMVar = stopMVar
|
||||||
@ -2081,8 +2103,8 @@ getAheadSVar st f = do
|
|||||||
(xs, _) <- readIORef q
|
(xs, _) <- readIORef q
|
||||||
return $ null xs
|
return $ null xs
|
||||||
|
|
||||||
getParallelSVar :: MonadIO m => State t m a -> IO (SVar t m a)
|
getParallelSVar :: MonadIO m => State t m a -> RunInIO m -> IO (SVar t m a)
|
||||||
getParallelSVar st = do
|
getParallelSVar st mrun = do
|
||||||
outQ <- newIORef ([], 0)
|
outQ <- newIORef ([], 0)
|
||||||
outQMv <- newEmptyMVar
|
outQMv <- newEmptyMVar
|
||||||
active <- newIORef 0
|
active <- newIORef 0
|
||||||
@ -2112,6 +2134,7 @@ getParallelSVar st = do
|
|||||||
, isQueueDone = undefined
|
, isQueueDone = undefined
|
||||||
, needDoorBell = undefined
|
, needDoorBell = undefined
|
||||||
, svarStyle = ParallelVar
|
, svarStyle = ParallelVar
|
||||||
|
, svarMrun = mrun
|
||||||
, workerCount = active
|
, workerCount = active
|
||||||
, accountThread = modifyThread sv
|
, accountThread = modifyThread sv
|
||||||
, workerStopMVar = undefined
|
, workerStopMVar = undefined
|
||||||
@ -2160,12 +2183,15 @@ newAheadVar :: MonadAsync m
|
|||||||
-> m ())
|
-> m ())
|
||||||
-> m (SVar t m a)
|
-> m (SVar t m a)
|
||||||
newAheadVar st m wloop = do
|
newAheadVar st m wloop = do
|
||||||
sv <- liftIO $ getAheadSVar st wloop
|
mrun <- captureMonadState
|
||||||
|
sv <- liftIO $ getAheadSVar st wloop mrun
|
||||||
sendFirstWorker sv m
|
sendFirstWorker sv m
|
||||||
|
|
||||||
{-# INLINABLE newParallelVar #-}
|
{-# INLINABLE newParallelVar #-}
|
||||||
newParallelVar :: MonadAsync m => State t m a -> m (SVar t m a)
|
newParallelVar :: MonadAsync m => State t m a -> m (SVar t m a)
|
||||||
newParallelVar st = liftIO $ getParallelSVar st
|
newParallelVar st = do
|
||||||
|
mrun <- captureMonadState
|
||||||
|
liftIO $ getParallelSVar st mrun
|
||||||
|
|
||||||
-- XXX this errors out for Parallel/Ahead SVars
|
-- XXX this errors out for Parallel/Ahead SVars
|
||||||
-- | Write a stream to an 'SVar' in a non-blocking manner. The stream can then
|
-- | Write a stream to an 'SVar' in a non-blocking manner. The stream can then
|
||||||
|
@ -47,7 +47,9 @@ import qualified Data.Heap as H
|
|||||||
import Streamly.Streams.SVar (fromSVar)
|
import Streamly.Streams.SVar (fromSVar)
|
||||||
import Streamly.Streams.Serial (map)
|
import Streamly.Streams.Serial (map)
|
||||||
import Streamly.SVar
|
import Streamly.SVar
|
||||||
import Streamly.Streams.StreamK (IsStream(..), Stream(..))
|
import Streamly.Streams.StreamK
|
||||||
|
(IsStream(..), Stream(..), unstreamShared, unStreamIsolated,
|
||||||
|
runStreamSVar)
|
||||||
import qualified Streamly.Streams.StreamK as K
|
import qualified Streamly.Streams.StreamK as K
|
||||||
|
|
||||||
import Prelude hiding (map)
|
import Prelude hiding (map)
|
||||||
@ -296,7 +298,7 @@ processHeap q heap st sv winfo entry sno stopping = loopHeap sno entry
|
|||||||
let stop = do
|
let stop = do
|
||||||
liftIO (incrementYieldLimit sv)
|
liftIO (incrementYieldLimit sv)
|
||||||
nextHeap seqNo
|
nextHeap seqNo
|
||||||
unStream r st stop
|
runStreamSVar sv r st stop
|
||||||
(singleStreamFromHeap seqNo)
|
(singleStreamFromHeap seqNo)
|
||||||
(yieldStreamFromHeap seqNo)
|
(yieldStreamFromHeap seqNo)
|
||||||
else liftIO $ do
|
else liftIO $ do
|
||||||
@ -346,7 +348,7 @@ processWithoutToken q heap st sv winfo m seqNo = do
|
|||||||
-- we stop.
|
-- we stop.
|
||||||
toHeap AheadEntryNull
|
toHeap AheadEntryNull
|
||||||
|
|
||||||
unStream m st stop
|
runStreamSVar sv m st stop
|
||||||
(toHeap . AheadEntryPure)
|
(toHeap . AheadEntryPure)
|
||||||
(\a r -> toHeap $ AheadEntryStream $ K.cons a r)
|
(\a r -> toHeap $ AheadEntryStream $ K.cons a r)
|
||||||
|
|
||||||
@ -406,7 +408,7 @@ processWithToken q heap st sv winfo action sno = do
|
|||||||
liftIO (incrementYieldLimit sv)
|
liftIO (incrementYieldLimit sv)
|
||||||
loopWithToken (sno + 1)
|
loopWithToken (sno + 1)
|
||||||
|
|
||||||
unStream action st stop (singleOutput sno) (yieldOutput sno)
|
runStreamSVar sv action st stop (singleOutput sno) (yieldOutput sno)
|
||||||
|
|
||||||
where
|
where
|
||||||
|
|
||||||
@ -429,7 +431,7 @@ processWithToken q heap st sv winfo action sno = do
|
|||||||
let stop = do
|
let stop = do
|
||||||
liftIO (incrementYieldLimit sv)
|
liftIO (incrementYieldLimit sv)
|
||||||
loopWithToken (seqNo + 1)
|
loopWithToken (seqNo + 1)
|
||||||
unStream r st stop
|
runStreamSVar sv r st stop
|
||||||
(singleOutput seqNo)
|
(singleOutput seqNo)
|
||||||
(yieldOutput seqNo)
|
(yieldOutput seqNo)
|
||||||
else do
|
else do
|
||||||
@ -458,7 +460,7 @@ processWithToken q heap st sv winfo action sno = do
|
|||||||
let stop = do
|
let stop = do
|
||||||
liftIO (incrementYieldLimit sv)
|
liftIO (incrementYieldLimit sv)
|
||||||
loopWithToken (seqNo + 1)
|
loopWithToken (seqNo + 1)
|
||||||
unStream m st stop
|
runStreamSVar sv m st stop
|
||||||
(singleOutput seqNo)
|
(singleOutput seqNo)
|
||||||
(yieldOutput seqNo)
|
(yieldOutput seqNo)
|
||||||
else
|
else
|
||||||
@ -681,10 +683,13 @@ aheadbind m f = go m
|
|||||||
where
|
where
|
||||||
go (Stream g) =
|
go (Stream g) =
|
||||||
Stream $ \st stp sng yld ->
|
Stream $ \st stp sng yld ->
|
||||||
let run x = unStream x st stp sng yld
|
let runShared x = unstreamShared x st stp sng yld
|
||||||
single a = run $ f a
|
runIsolated x = unStreamIsolated x st stp sng yld
|
||||||
yieldk a r = run $ f a `aheadS` go r
|
|
||||||
in g (rstState st) stp single yieldk
|
single a = runIsolated $ f a
|
||||||
|
yieldk a r = runShared $
|
||||||
|
K.isolateStream (f a) `aheadS` go r
|
||||||
|
in g (rstState st) stp single yieldk
|
||||||
|
|
||||||
instance MonadAsync m => Monad (AheadT m) where
|
instance MonadAsync m => Monad (AheadT m) where
|
||||||
return = pure
|
return = pure
|
||||||
|
@ -57,7 +57,7 @@ import qualified Data.Set as S
|
|||||||
import Streamly.Streams.SVar (fromSVar)
|
import Streamly.Streams.SVar (fromSVar)
|
||||||
import Streamly.Streams.Serial (map)
|
import Streamly.Streams.Serial (map)
|
||||||
import Streamly.SVar
|
import Streamly.SVar
|
||||||
import Streamly.Streams.StreamK (IsStream(..), Stream(..), adapt)
|
import Streamly.Streams.StreamK (IsStream(..), Stream(..), adapt, runStreamSVar)
|
||||||
import qualified Streamly.Streams.StreamK as K
|
import qualified Streamly.Streams.StreamK as K
|
||||||
|
|
||||||
#include "Instances.hs"
|
#include "Instances.hs"
|
||||||
@ -82,7 +82,7 @@ workLoopLIFO q st sv winfo = run
|
|||||||
work <- dequeue
|
work <- dequeue
|
||||||
case work of
|
case work of
|
||||||
Nothing -> liftIO $ sendStop sv winfo
|
Nothing -> liftIO $ sendStop sv winfo
|
||||||
Just m -> unStream m st run single yieldk
|
Just m -> runStreamSVar sv m st run single yieldk
|
||||||
|
|
||||||
single a = do
|
single a = do
|
||||||
res <- liftIO $ sendYield sv winfo (ChildYield a)
|
res <- liftIO $ sendYield sv winfo (ChildYield a)
|
||||||
@ -91,7 +91,7 @@ workLoopLIFO q st sv winfo = run
|
|||||||
yieldk a r = do
|
yieldk a r = do
|
||||||
res <- liftIO $ sendYield sv winfo (ChildYield a)
|
res <- liftIO $ sendYield sv winfo (ChildYield a)
|
||||||
if res
|
if res
|
||||||
then unStream r st run single yieldk
|
then runStreamSVar sv r st run single yieldk
|
||||||
else liftIO $ do
|
else liftIO $ do
|
||||||
enqueueLIFO sv q r
|
enqueueLIFO sv q r
|
||||||
sendStop sv winfo
|
sendStop sv winfo
|
||||||
@ -132,7 +132,7 @@ workLoopLIFOLimited q st sv winfo = run
|
|||||||
if yieldLimitOk
|
if yieldLimitOk
|
||||||
then do
|
then do
|
||||||
let stop = liftIO (incrementYieldLimit sv) >> run
|
let stop = liftIO (incrementYieldLimit sv) >> run
|
||||||
unStream m st stop single yieldk
|
runStreamSVar sv m st stop single yieldk
|
||||||
-- Avoid any side effects, undo the yield limit decrement if we
|
-- Avoid any side effects, undo the yield limit decrement if we
|
||||||
-- never yielded anything.
|
-- never yielded anything.
|
||||||
else liftIO $ do
|
else liftIO $ do
|
||||||
@ -151,7 +151,7 @@ workLoopLIFOLimited q st sv winfo = run
|
|||||||
yieldLimitOk <- liftIO $ decrementYieldLimit sv
|
yieldLimitOk <- liftIO $ decrementYieldLimit sv
|
||||||
let stop = liftIO (incrementYieldLimit sv) >> run
|
let stop = liftIO (incrementYieldLimit sv) >> run
|
||||||
if res && yieldLimitOk
|
if res && yieldLimitOk
|
||||||
then unStream r st stop single yieldk
|
then runStreamSVar sv r st stop single yieldk
|
||||||
else liftIO $ do
|
else liftIO $ do
|
||||||
incrementYieldLimit sv
|
incrementYieldLimit sv
|
||||||
enqueueLIFO sv q r
|
enqueueLIFO sv q r
|
||||||
@ -183,7 +183,7 @@ workLoopFIFO q st sv winfo = run
|
|||||||
work <- liftIO $ tryPopR q
|
work <- liftIO $ tryPopR q
|
||||||
case work of
|
case work of
|
||||||
Nothing -> liftIO $ sendStop sv winfo
|
Nothing -> liftIO $ sendStop sv winfo
|
||||||
Just m -> unStream m st run single yieldk
|
Just m -> runStreamSVar sv m st run single yieldk
|
||||||
|
|
||||||
single a = do
|
single a = do
|
||||||
res <- liftIO $ sendYield sv winfo (ChildYield a)
|
res <- liftIO $ sendYield sv winfo (ChildYield a)
|
||||||
@ -192,7 +192,7 @@ workLoopFIFO q st sv winfo = run
|
|||||||
yieldk a r = do
|
yieldk a r = do
|
||||||
res <- liftIO $ sendYield sv winfo (ChildYield a)
|
res <- liftIO $ sendYield sv winfo (ChildYield a)
|
||||||
if res
|
if res
|
||||||
then unStream r st run single yieldk
|
then runStreamSVar sv r st run single yieldk
|
||||||
else liftIO $ do
|
else liftIO $ do
|
||||||
enqueueFIFO sv q r
|
enqueueFIFO sv q r
|
||||||
sendStop sv winfo
|
sendStop sv winfo
|
||||||
@ -218,7 +218,7 @@ workLoopFIFOLimited q st sv winfo = run
|
|||||||
if yieldLimitOk
|
if yieldLimitOk
|
||||||
then do
|
then do
|
||||||
let stop = liftIO (incrementYieldLimit sv) >> run
|
let stop = liftIO (incrementYieldLimit sv) >> run
|
||||||
unStream m st stop single yieldk
|
runStreamSVar sv m st stop single yieldk
|
||||||
else liftIO $ do
|
else liftIO $ do
|
||||||
enqueueFIFO sv q m
|
enqueueFIFO sv q m
|
||||||
incrementYieldLimit sv
|
incrementYieldLimit sv
|
||||||
@ -233,7 +233,7 @@ workLoopFIFOLimited q st sv winfo = run
|
|||||||
yieldLimitOk <- liftIO $ decrementYieldLimit sv
|
yieldLimitOk <- liftIO $ decrementYieldLimit sv
|
||||||
let stop = liftIO (incrementYieldLimit sv) >> run
|
let stop = liftIO (incrementYieldLimit sv) >> run
|
||||||
if res && yieldLimitOk
|
if res && yieldLimitOk
|
||||||
then unStream r st stop single yieldk
|
then runStreamSVar sv r st stop single yieldk
|
||||||
else liftIO $ do
|
else liftIO $ do
|
||||||
incrementYieldLimit sv
|
incrementYieldLimit sv
|
||||||
enqueueFIFO sv q r
|
enqueueFIFO sv q r
|
||||||
@ -249,8 +249,8 @@ workLoopFIFOLimited q st sv winfo = run
|
|||||||
-- than 10%. Need to investigate what the root cause is.
|
-- than 10%. Need to investigate what the root cause is.
|
||||||
-- Interestingly, the same thing does not make any difference for Ahead.
|
-- Interestingly, the same thing does not make any difference for Ahead.
|
||||||
getLifoSVar :: forall m a. MonadAsync m
|
getLifoSVar :: forall m a. MonadAsync m
|
||||||
=> State Stream m a -> IO (SVar Stream m a)
|
=> State Stream m a -> RunInIO m -> IO (SVar Stream m a)
|
||||||
getLifoSVar st = do
|
getLifoSVar st mrun = do
|
||||||
outQ <- newIORef ([], 0)
|
outQ <- newIORef ([], 0)
|
||||||
outQMv <- newEmptyMVar
|
outQMv <- newEmptyMVar
|
||||||
active <- newIORef 0
|
active <- newIORef 0
|
||||||
@ -303,6 +303,7 @@ getLifoSVar st = do
|
|||||||
, isQueueDone = workDone sv
|
, isQueueDone = workDone sv
|
||||||
, needDoorBell = wfw
|
, needDoorBell = wfw
|
||||||
, svarStyle = AsyncVar
|
, svarStyle = AsyncVar
|
||||||
|
, svarMrun = mrun
|
||||||
, workerCount = active
|
, workerCount = active
|
||||||
, accountThread = delThread sv
|
, accountThread = delThread sv
|
||||||
, workerStopMVar = undefined
|
, workerStopMVar = undefined
|
||||||
@ -339,8 +340,8 @@ getLifoSVar st = do
|
|||||||
in return sv
|
in return sv
|
||||||
|
|
||||||
getFifoSVar :: forall m a. MonadAsync m
|
getFifoSVar :: forall m a. MonadAsync m
|
||||||
=> State Stream m a -> IO (SVar Stream m a)
|
=> State Stream m a -> RunInIO m -> IO (SVar Stream m a)
|
||||||
getFifoSVar st = do
|
getFifoSVar st mrun = do
|
||||||
outQ <- newIORef ([], 0)
|
outQ <- newIORef ([], 0)
|
||||||
outQMv <- newEmptyMVar
|
outQMv <- newEmptyMVar
|
||||||
active <- newIORef 0
|
active <- newIORef 0
|
||||||
@ -392,6 +393,7 @@ getFifoSVar st = do
|
|||||||
, isQueueDone = workDone sv
|
, isQueueDone = workDone sv
|
||||||
, needDoorBell = wfw
|
, needDoorBell = wfw
|
||||||
, svarStyle = WAsyncVar
|
, svarStyle = WAsyncVar
|
||||||
|
, svarMrun = mrun
|
||||||
, workerCount = active
|
, workerCount = active
|
||||||
, accountThread = delThread sv
|
, accountThread = delThread sv
|
||||||
, workerStopMVar = undefined
|
, workerStopMVar = undefined
|
||||||
@ -431,7 +433,8 @@ getFifoSVar st = do
|
|||||||
newAsyncVar :: MonadAsync m
|
newAsyncVar :: MonadAsync m
|
||||||
=> State Stream m a -> Stream m a -> m (SVar Stream m a)
|
=> State Stream m a -> Stream m a -> m (SVar Stream m a)
|
||||||
newAsyncVar st m = do
|
newAsyncVar st m = do
|
||||||
sv <- liftIO $ getLifoSVar st
|
mrun <- captureMonadState
|
||||||
|
sv <- liftIO $ getLifoSVar st mrun
|
||||||
sendFirstWorker sv m
|
sendFirstWorker sv m
|
||||||
|
|
||||||
-- XXX Get rid of this?
|
-- XXX Get rid of this?
|
||||||
@ -455,7 +458,8 @@ mkAsync' st m = fmap fromSVar (newAsyncVar st (toStream m))
|
|||||||
newWAsyncVar :: MonadAsync m
|
newWAsyncVar :: MonadAsync m
|
||||||
=> State Stream m a -> Stream m a -> m (SVar Stream m a)
|
=> State Stream m a -> Stream m a -> m (SVar Stream m a)
|
||||||
newWAsyncVar st m = do
|
newWAsyncVar st m = do
|
||||||
sv <- liftIO $ getFifoSVar st
|
mrun <- captureMonadState
|
||||||
|
sv <- liftIO $ getFifoSVar st mrun
|
||||||
sendFirstWorker sv m
|
sendFirstWorker sv m
|
||||||
|
|
||||||
------------------------------------------------------------------------------
|
------------------------------------------------------------------------------
|
||||||
|
@ -67,6 +67,7 @@ runOne st m winfo = unStream m st stop single yieldk
|
|||||||
where
|
where
|
||||||
|
|
||||||
sv = fromJust $ streamVar st
|
sv = fromJust $ streamVar st
|
||||||
|
mrun = runInIO $ svarMrun sv
|
||||||
|
|
||||||
withLimitCheck action = do
|
withLimitCheck action = do
|
||||||
yieldLimitOk <- liftIO $ decrementYieldLimitPost sv
|
yieldLimitOk <- liftIO $ decrementYieldLimitPost sv
|
||||||
@ -82,7 +83,8 @@ runOne st m winfo = unStream m st stop single yieldk
|
|||||||
-- queue and queue it back on that and exit the thread when the outputQueue
|
-- queue and queue it back on that and exit the thread when the outputQueue
|
||||||
-- overflows. Parallel is dangerous because it can accumulate unbounded
|
-- overflows. Parallel is dangerous because it can accumulate unbounded
|
||||||
-- output in the buffer.
|
-- output in the buffer.
|
||||||
yieldk a r = void (sendit a) >> withLimitCheck (runOne st r winfo)
|
yieldk a r = void (sendit a)
|
||||||
|
>> withLimitCheck (void $ liftIO $ mrun $ runOne st r winfo)
|
||||||
|
|
||||||
{-# NOINLINE forkSVarPar #-}
|
{-# NOINLINE forkSVarPar #-}
|
||||||
forkSVarPar :: MonadAsync m => Stream m a -> Stream m a -> Stream m a
|
forkSVarPar :: MonadAsync m => Stream m a -> Stream m a -> Stream m a
|
||||||
|
@ -33,6 +33,10 @@ module Streamly.Streams.StreamK
|
|||||||
|
|
||||||
-- * The stream type
|
-- * The stream type
|
||||||
, Stream (..)
|
, Stream (..)
|
||||||
|
, unStreamIsolated
|
||||||
|
, isolateStream
|
||||||
|
, unstreamShared
|
||||||
|
, runStreamSVar
|
||||||
|
|
||||||
-- * Construction
|
-- * Construction
|
||||||
, mkStream
|
, mkStream
|
||||||
@ -140,6 +144,7 @@ module Streamly.Streams.StreamK
|
|||||||
where
|
where
|
||||||
|
|
||||||
import Control.Monad (void)
|
import Control.Monad (void)
|
||||||
|
import Control.Monad.IO.Class (MonadIO(liftIO))
|
||||||
import Control.Monad.Reader.Class (MonadReader(..))
|
import Control.Monad.Reader.Class (MonadReader(..))
|
||||||
import Control.Monad.Trans.Class (MonadTrans(lift))
|
import Control.Monad.Trans.Class (MonadTrans(lift))
|
||||||
import Data.Semigroup (Semigroup(..))
|
import Data.Semigroup (Semigroup(..))
|
||||||
@ -183,6 +188,51 @@ newtype Stream m a =
|
|||||||
-> m r
|
-> m r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
-- XXX make this the default "unStream"
|
||||||
|
-- | unwraps the Stream type producing the stream function that can be run with
|
||||||
|
-- continuations.
|
||||||
|
{-# INLINE unStreamIsolated #-}
|
||||||
|
unStreamIsolated ::
|
||||||
|
Stream m a
|
||||||
|
-> State Stream m a -- state
|
||||||
|
-> m r -- stop
|
||||||
|
-> (a -> m r) -- singleton
|
||||||
|
-> (a -> Stream m a -> m r) -- yield
|
||||||
|
-> m r
|
||||||
|
unStreamIsolated x st = unStream x (rstState st)
|
||||||
|
|
||||||
|
{-# INLINE isolateStream #-}
|
||||||
|
isolateStream :: Stream m a -> Stream m a
|
||||||
|
isolateStream x = Stream $ \st stp sng yld ->
|
||||||
|
unStreamIsolated x st stp sng yld
|
||||||
|
|
||||||
|
-- | Like unstream, but passes a shared SVar across continuations.
|
||||||
|
{-# INLINE unstreamShared #-}
|
||||||
|
unstreamShared ::
|
||||||
|
Stream m a
|
||||||
|
-> State Stream m a -- state
|
||||||
|
-> m r -- stop
|
||||||
|
-> (a -> m r) -- singleton
|
||||||
|
-> (a -> Stream m a -> m r) -- yield
|
||||||
|
-> m r
|
||||||
|
unstreamShared = unStream
|
||||||
|
|
||||||
|
-- Run the stream using a run function associated with the SVar that runs the
|
||||||
|
-- streams with a captured snapshot of the monadic state.
|
||||||
|
{-# INLINE runStreamSVar #-}
|
||||||
|
runStreamSVar
|
||||||
|
:: MonadIO m
|
||||||
|
=> SVar Stream m a
|
||||||
|
-> Stream m a
|
||||||
|
-> State Stream m a -- state
|
||||||
|
-> m r -- stop
|
||||||
|
-> (a -> m r) -- singleton
|
||||||
|
-> (a -> Stream m a -> m r) -- yield
|
||||||
|
-> m ()
|
||||||
|
runStreamSVar sv m st stp sng yld =
|
||||||
|
let mrun = runInIO $ svarMrun sv
|
||||||
|
in void $ liftIO $ mrun $ unStream m st stp sng yld
|
||||||
|
|
||||||
------------------------------------------------------------------------------
|
------------------------------------------------------------------------------
|
||||||
-- Types that can behave as a Stream
|
-- Types that can behave as a Stream
|
||||||
------------------------------------------------------------------------------
|
------------------------------------------------------------------------------
|
||||||
@ -969,10 +1019,12 @@ bindWith par m f = go m
|
|||||||
where
|
where
|
||||||
go (Stream g) =
|
go (Stream g) =
|
||||||
Stream $ \st stp sng yld ->
|
Stream $ \st stp sng yld ->
|
||||||
let run x = unStream x st stp sng yld
|
let runShared x = unstreamShared x st stp sng yld
|
||||||
single a = run $ f a
|
runIsolated x = unStreamIsolated x st stp sng yld
|
||||||
yieldk a r = run $ f a `par` go r
|
|
||||||
in g (rstState st) stp single yieldk
|
single a = runIsolated $ f a
|
||||||
|
yieldk a r = runShared $ isolateStream (f a) `par` go r
|
||||||
|
in g (rstState st) stp single yieldk
|
||||||
|
|
||||||
------------------------------------------------------------------------------
|
------------------------------------------------------------------------------
|
||||||
-- Alternative & MonadPlus
|
-- Alternative & MonadPlus
|
||||||
|
102
test/Main.hs
102
test/Main.hs
@ -7,8 +7,11 @@ module Main (main) where
|
|||||||
|
|
||||||
import Control.Concurrent (threadDelay)
|
import Control.Concurrent (threadDelay)
|
||||||
import Control.Exception (Exception, try, ErrorCall(..), catch, throw)
|
import Control.Exception (Exception, try, ErrorCall(..), catch, throw)
|
||||||
|
import Control.Monad (void)
|
||||||
import Control.Monad.Catch (throwM, MonadThrow)
|
import Control.Monad.Catch (throwM, MonadThrow)
|
||||||
import Control.Monad.Error.Class (throwError, MonadError)
|
import Control.Monad.Error.Class (throwError, MonadError)
|
||||||
|
import Control.Monad.IO.Class (MonadIO(liftIO))
|
||||||
|
import Control.Monad.State (MonadState, get, modify, runStateT, StateT)
|
||||||
import Control.Monad.Trans.Except (runExceptT, ExceptT)
|
import Control.Monad.Trans.Except (runExceptT, ExceptT)
|
||||||
import Data.Foldable (forM_, fold)
|
import Data.Foldable (forM_, fold)
|
||||||
import Data.List (sort)
|
import Data.List (sort)
|
||||||
@ -495,6 +498,34 @@ parallelTests = H.parallel $ do
|
|||||||
it "foldlM' is strict enough" (checkFoldMStrictness foldlM'StrictCheck)
|
it "foldlM' is strict enough" (checkFoldMStrictness foldlM'StrictCheck)
|
||||||
it "scanlM' is strict enough" (checkScanlMStrictness scanlM'StrictCheck)
|
it "scanlM' is strict enough" (checkScanlMStrictness scanlM'StrictCheck)
|
||||||
|
|
||||||
|
---------------------------------------------------------------------------
|
||||||
|
-- Monadic state snapshot in concurrent tasks
|
||||||
|
---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
it "asyncly maintains independent states in concurrent tasks"
|
||||||
|
(monadicStateSnapshot asyncly)
|
||||||
|
it "asyncly limited maintains independent states in concurrent tasks"
|
||||||
|
(monadicStateSnapshot (asyncly . S.take 10000))
|
||||||
|
it "wAsyncly maintains independent states in concurrent tasks"
|
||||||
|
(monadicStateSnapshot wAsyncly)
|
||||||
|
it "wAsyncly limited maintains independent states in concurrent tasks"
|
||||||
|
(monadicStateSnapshot (wAsyncly . S.take 10000))
|
||||||
|
it "aheadly maintains independent states in concurrent tasks"
|
||||||
|
(monadicStateSnapshot aheadly)
|
||||||
|
it "aheadly limited maintains independent states in concurrent tasks"
|
||||||
|
(monadicStateSnapshot (aheadly . S.take 10000))
|
||||||
|
it "parallely maintains independent states in concurrent tasks"
|
||||||
|
(monadicStateSnapshot parallely)
|
||||||
|
|
||||||
|
it "async maintains independent states in concurrent tasks"
|
||||||
|
(monadicStateSnapshotOp async)
|
||||||
|
it "ahead maintains independent states in concurrent tasks"
|
||||||
|
(monadicStateSnapshotOp ahead)
|
||||||
|
it "wAsync maintains independent states in concurrent tasks"
|
||||||
|
(monadicStateSnapshotOp wAsync)
|
||||||
|
it "parallel maintains independent states in concurrent tasks"
|
||||||
|
(monadicStateSnapshotOp Streamly.parallel)
|
||||||
|
|
||||||
---------------------------------------------------------------------------
|
---------------------------------------------------------------------------
|
||||||
-- Slower tests are at the end
|
-- Slower tests are at the end
|
||||||
---------------------------------------------------------------------------
|
---------------------------------------------------------------------------
|
||||||
@ -513,6 +544,77 @@ parallelTests = H.parallel $ do
|
|||||||
replicate 4000 $ S.yieldM $ threadDelay 1000000)
|
replicate 4000 $ S.yieldM $ threadDelay 1000000)
|
||||||
`shouldReturn` ()
|
`shouldReturn` ()
|
||||||
|
|
||||||
|
-- Each snapshot carries an independent state. Multiple parallel tasks should
|
||||||
|
-- not affect each other's state. This is especially important when we run
|
||||||
|
-- multiple tasks in a single thread.
|
||||||
|
snapshot :: (IsStream t, MonadAsync m, MonadState Int m) => t m ()
|
||||||
|
snapshot =
|
||||||
|
-- We deliberately use a replicate count 1 here, because a lower count
|
||||||
|
-- catches problems that a higher count doesn't.
|
||||||
|
S.replicateM 1 $ do
|
||||||
|
-- Even though we modify the state here it should not reflect in other
|
||||||
|
-- parallel tasks, it is local to each concurrent task.
|
||||||
|
modify (+1) >> get >>= liftIO . (`shouldSatisfy` (==1))
|
||||||
|
modify (+1) >> get >>= liftIO . (`shouldSatisfy` (==2))
|
||||||
|
|
||||||
|
snapshot1 :: (IsStream t, MonadAsync m, MonadState Int m) => t m ()
|
||||||
|
snapshot1 = S.replicateM 1000 $
|
||||||
|
modify (+1) >> get >>= liftIO . (`shouldSatisfy` (==2))
|
||||||
|
|
||||||
|
snapshot2 :: (IsStream t, MonadAsync m, MonadState Int m) => t m ()
|
||||||
|
snapshot2 = S.replicateM 1000 $
|
||||||
|
modify (+1) >> get >>= liftIO . (`shouldSatisfy` (==2))
|
||||||
|
|
||||||
|
stateComp
|
||||||
|
:: ( IsStream t
|
||||||
|
, MonadAsync m
|
||||||
|
, Semigroup (t m ())
|
||||||
|
, MonadIO (t m)
|
||||||
|
, MonadState Int m
|
||||||
|
, MonadState Int (t m)
|
||||||
|
)
|
||||||
|
=> t m ()
|
||||||
|
stateComp = do
|
||||||
|
-- Each task in a concurrent composition inherits the state and maintains
|
||||||
|
-- its own modifications to it, not affecting the parent computation.
|
||||||
|
snapshot <> (modify (+1) >> (snapshot1 <> snapshot2))
|
||||||
|
-- The above modify statement does not affect our state because that is
|
||||||
|
-- used in a parallel composition. In a serial composition it will affect
|
||||||
|
-- our state.
|
||||||
|
get >>= liftIO . (`shouldSatisfy` (== (0 :: Int)))
|
||||||
|
|
||||||
|
monadicStateSnapshot
|
||||||
|
:: ( IsStream t
|
||||||
|
, Semigroup (t (StateT Int IO) ())
|
||||||
|
, MonadIO (t (StateT Int IO))
|
||||||
|
, MonadState Int (t (StateT Int IO))
|
||||||
|
)
|
||||||
|
=> (t (StateT Int IO) () -> SerialT (StateT Int IO) ()) -> IO ()
|
||||||
|
monadicStateSnapshot t = void $ runStateT (runStream $ t stateComp) 0
|
||||||
|
|
||||||
|
stateCompOp
|
||||||
|
:: ( AsyncT (StateT Int IO) ()
|
||||||
|
-> AsyncT (StateT Int IO) ()
|
||||||
|
-> AsyncT (StateT Int IO) ()
|
||||||
|
)
|
||||||
|
-> SerialT (StateT Int IO) ()
|
||||||
|
stateCompOp op = do
|
||||||
|
-- Each task in a concurrent composition inherits the state and maintains
|
||||||
|
-- its own modifications to it, not affecting the parent computation.
|
||||||
|
asyncly (snapshot `op` (modify (+1) >> (snapshot1 `op` snapshot2)))
|
||||||
|
-- The above modify statement does not affect our state because that is
|
||||||
|
-- used in a parallel composition. In a serial composition it will affect
|
||||||
|
-- our state.
|
||||||
|
get >>= liftIO . (`shouldSatisfy` (== (0 :: Int)))
|
||||||
|
|
||||||
|
monadicStateSnapshotOp
|
||||||
|
:: ( AsyncT (StateT Int IO) ()
|
||||||
|
-> AsyncT (StateT Int IO) ()
|
||||||
|
-> AsyncT (StateT Int IO) ()
|
||||||
|
)
|
||||||
|
-> IO ()
|
||||||
|
monadicStateSnapshotOp op = void $ runStateT (runStream $ stateCompOp op) 0
|
||||||
|
|
||||||
takeCombined :: (Monad m, Semigroup (t m Int), Show a, Eq a, IsStream t)
|
takeCombined :: (Monad m, Semigroup (t m Int), Show a, Eq a, IsStream t)
|
||||||
=> Int -> (t m Int -> SerialT IO a) -> IO ()
|
=> Int -> (t m Int -> SerialT IO a) -> IO ()
|
||||||
takeCombined n t = do
|
takeCombined n t = do
|
||||||
|
Loading…
Reference in New Issue
Block a user