Add the wait primitive to discard the results

This commit is contained in:
Harendra Kumar 2017-06-26 03:14:36 +05:30
parent d4f2fec43f
commit b9c1c92dc0
4 changed files with 34 additions and 15 deletions

View File

@ -9,7 +9,8 @@
--
module Strands
( gather
( wait
, gather
, async
, sample
, threads

View File

@ -88,6 +88,8 @@ data Context = Context
-- the channel until all its children are cleaned up.
---------------------------------------------------------------------------
-- XXX setup a cleanup computation to run rather than passing all these
-- params.
-- XXX use Either parentChannel accumResults
, accumResults :: Maybe (IORef [Any])
-- ^ Accumulated results, only the top level thread context accumulates.

View File

@ -9,6 +9,8 @@ module Strands.Threads
, sync
--, react
, threads
, wait
, gather
)
where
@ -25,7 +27,7 @@ import qualified Control.Exception.Lifted as EL
import Control.Monad.Catch (MonadCatch, MonadThrow, throwM,
try)
import Control.Monad.IO.Class (MonadIO (..))
import Control.Monad.State (get, gets, modify,
import Control.Monad.State (get, gets, modify, mzero,
runStateT, when, StateT)
import Control.Monad.Trans.Class (MonadTrans (lift))
import Control.Monad.Trans.Control (MonadBaseControl, liftBaseWith)
@ -531,30 +533,23 @@ collectResult r = do
liftIO $ atomically $ writeTChan chan
(ChildResult (Right (unsafeCoerce r)))
-- | Invoked to store the result of the computation in the context and finish
-- the computation when the computation is done
finishComputation :: MonadIO m => a -> AsyncT m b
finishComputation x = AsyncT $ do
collectResult x
return Nothing
-- XXX pass a collector function and return a Traversable?
-- XXX Ideally it should be a non-empty list.
-- | Run an 'AsyncT m' computation and collect the results generated by each
-- thread of the computation in a list.
gather :: forall m a. (MonadIO m, MonadCatch m)
=> AsyncT m a -> m [a]
gather m = do
waitAsync :: forall m a b. (MonadIO m, MonadCatch m)
=> (a -> AsyncT m a) -> AsyncT m a -> m [a]
waitAsync finalizer m = do
childChan <- liftIO $ atomically newTChan
pendingRef <- liftIO $ newIORef []
resultsRef <- liftIO $ newIORef []
credit <- liftIO $ newIORef maxBound
let ctx = initContext (empty :: AsyncT m a) childChan pendingRef credit
finishComputation resultsRef
finalizer resultsRef
r <- try $ runStateT (runAsyncT $ m >>= finishComputation) ctx
r <- try $ runStateT (runAsyncT $ m >>= finalizer) ctx
case r of
Left (exc :: SomeException) -> do
@ -565,3 +560,24 @@ gather m = do
case e of
Just (exc :: SomeException) -> throwM exc
Nothing -> liftIO $ readIORef resultsRef
-- | Invoked to store the result of the computation in the context and finish
-- the computation when the computation is done
gatherResult :: MonadIO m => a -> AsyncT m a
gatherResult x = AsyncT $ do
collectResult x
return Nothing
-- | Run an 'AsyncT m' computation and collect the results generated by each
-- thread of the computation in a list.
gather :: forall m a. (MonadIO m, MonadCatch m)
=> AsyncT m a -> m [a]
gather m = waitAsync gatherResult m
-- | Run an 'AsyncT m' computation, wait for it to finish and discard the
-- results.
wait :: forall m a. (MonadIO m, MonadCatch m)
=> AsyncT m a -> m ()
wait m = do
_ <- waitAsync (const mzero) m
return ()

View File

@ -5,7 +5,7 @@ import System.IO
import Strands
main = waitAsync $ threads 3 $ do
main = wait $ threads 3 $ do
liftIO $ hSetBuffering stdout LineBuffering
mainThread <- liftIO myThreadId
liftIO $ putStrLn $ "Main thread: " ++ show mainThread