Refactor/fix returning results from child threads

This commit is contained in:
Harendra Kumar 2017-06-25 21:23:14 +05:30
parent f49eccb296
commit 8a0fdb3abd
4 changed files with 118 additions and 157 deletions

View File

@ -12,6 +12,10 @@
module Strands.AsyncT
( AsyncT (..)
, waitAsync
, processOneEvent
, drainChildren
, waitForChildren
, getCtxResultDest
, (<**)
, onNothing
, dbg
@ -21,7 +25,7 @@ where
import Control.Applicative (Alternative (..))
import Control.Concurrent (ThreadId, killThread)
import Control.Concurrent.STM (TChan, atomically, newTChan,
readTChan)
readTChan, writeTChan)
import Control.Monad.Base (MonadBase (..), liftBaseDefault)
import Control.Monad.Catch (MonadCatch, MonadThrow, try,
throwM, SomeException)
@ -34,14 +38,15 @@ import Control.Monad.Trans.Control (ComposeSt, MonadBaseControl (..),
defaultLiftBaseWith,
defaultRestoreM)
import Data.Dynamic (Typeable)
import Data.IORef (IORef, newIORef, readIORef,
import Data.IORef (IORef, modifyIORef, newIORef, readIORef,
writeIORef)
import Data.List (delete)
import Data.Maybe (isJust, isNothing)
import Data.Maybe (fromJust, isJust, isNothing)
import Unsafe.Coerce (unsafeCoerce)
import Strands.Context
-- import Debug.Trace (traceM)
--import Debug.Trace (traceM)
newtype AsyncT m a = AsyncT { runAsyncT :: StateT Context m (Maybe a) }
@ -156,43 +161,82 @@ instance MonadThrow m => MonadThrow (AsyncT m) where
-- Thread management
------------------------------------------------------------------------------
drainChildren :: (MonadIO m, MonadThrow m)
=> TChan (ChildEvent a) -> [ThreadId] -> [a] -> m [a]
drainChildren chan pending res =
if pending == []
then return res
else do
ev <- liftIO $ atomically $ readTChan chan
case ev of
ChildDone tid er -> do
dbg $ "drainChildrenTop ChildDone, tid: " ++ show tid
case er of
Left e -> throwM e
Right r ->
drainChildren chan (delete tid pending) (r ++ res)
PassOnResult er -> do
dbg $ "drainChildrenTop PassOnResult"
case er of
Left e -> throwM e
Right r -> drainChildren chan pending (r ++ res)
-- XXX We are using unbounded channels so this will not block on writing to
-- pchan. We can use bounded channels to throttle the creation of threads based
-- on consumption rate.
processOneEvent :: MonadIO m
=> ChildEvent a
-> Either (TChan (ChildEvent a)) (IORef [a])
-> [ThreadId]
-> m ([ThreadId], Maybe SomeException)
processOneEvent ev dest pending = do
-- Collect results unless we have already encountered an exception.
case ev of
ChildDone _ (Just e) -> handleException e
ChildResult (Left e) -> handleException e
ChildDone tid Nothing -> return (delete tid pending, Nothing)
ChildResult (Right r) -> do
case dest of
Left chan -> liftIO $ atomically $ writeTChan chan ev
Right ref ->
liftIO $ modifyIORef ref $ \rs -> unsafeCoerce r : rs
return (pending, Nothing)
where
waitForChildren :: (MonadIO m, MonadThrow m)
=> TChan (ChildEvent a) -> IORef [ThreadId] -> [a] -> m [a]
waitForChildren chan pendingRef results = do
handleException e = do
liftIO $ mapM_ killThread pending
return (pending, Just e)
drainChildren :: MonadIO m
=> Either (TChan (ChildEvent a)) (IORef [a])
-> TChan (ChildEvent a)
-> [ThreadId]
-> m ([ThreadId], Maybe SomeException)
drainChildren dest cchan pending =
case pending of
[] -> return (pending, Nothing)
_ -> do
ev <- liftIO $ atomically $ readTChan cchan
(p, e) <- processOneEvent ev dest pending
maybe (drainChildren dest cchan p) (const $ return (p, e)) e
waitForChildren :: MonadIO m => Context -> m (Maybe SomeException)
waitForChildren ctx = do
let pendingRef = pendingThreads ctx
pending <- liftIO $ readIORef pendingRef
r <- drainChildren chan pending results
liftIO $ writeIORef pendingRef []
return r
(p, e) <- drainChildren (getCtxResultDest ctx) (childChannel ctx) pending
liftIO $ writeIORef pendingRef p
return e
getCtxResultDest :: Context -> Either (TChan (ChildEvent a)) (IORef [a])
getCtxResultDest ctx =
maybe (Right $ unsafeCoerce $ fromJust $ accumResults ctx)
(Left . unsafeCoerce) (parentChannel ctx)
------------------------------------------------------------------------------
-- Running the monad
------------------------------------------------------------------------------
collectResult :: MonadIO m => a -> StateT Context m ()
collectResult r = do
ctx <- get
case parentChannel ctx of
Nothing -> do
let ref = fromJust $ accumResults ctx
liftIO $ modifyIORef ref $ \rs -> unsafeCoerce r : rs
Just chan -> do
-- XXX can we pass the result directly to the root thread
-- instead of passing through all the parents? We can let the
-- parent go away and handle the ChildDone events as well in
-- the root thread.
liftIO $ atomically $ writeTChan chan
(ChildResult (Right (unsafeCoerce r)))
-- | Invoked to store the result of the computation in the context and finish
-- the computation when the computation is done
finishComputation :: Monad m => a -> AsyncT m b
finishComputation :: MonadIO m => a -> AsyncT m b
finishComputation x = AsyncT $ do
contextSaveResult x
collectResult x
return Nothing
-- XXX pass a collector function and return a Traversable.
@ -204,21 +248,23 @@ waitAsync :: forall m a. (MonadIO m, MonadCatch m)
waitAsync m = do
childChan <- liftIO $ atomically newTChan
pendingRef <- liftIO $ newIORef []
resultsRef <- liftIO $ newIORef []
credit <- liftIO $ newIORef maxBound
-- XXX this should be moved to Context.hs and then we can make m
-- existential and remove the unsafeCoerces
r <- try $ runStateT (runAsyncT m) $ initContext
(empty :: AsyncT m a) childChan pendingRef credit
finishComputation
let ctx = initContext (empty :: AsyncT m a) childChan pendingRef credit
finishComputation resultsRef
xs <- case r of
r <- try $ runStateT (runAsyncT $ m >>= finishComputation) ctx
case r of
Left (exc :: SomeException) -> do
liftIO $ readIORef pendingRef >>= mapM_ killThread
throwM exc
Right (Nothing, ctx) -> return $ contextGetResult ctx
Right ((Just x), ctx) -> return $ x : contextGetResult ctx
waitForChildren childChan pendingRef xs
Right _ -> do
e <- waitForChildren ctx
case e of
Just (exc :: SomeException) -> throwM exc
Nothing -> liftIO $ readIORef resultsRef
------------------------------------------------------------------------------
-- * Extensible State: Session Data Management

View File

@ -9,8 +9,6 @@ module Strands.Context
, saveContext
, restoreContext
, resumeContext
, contextSaveResult
, contextGetResult
, setContextMailBox
, takeContextMailBox
, Location(..)
@ -40,8 +38,8 @@ import GHC.Prim (Any)
------------------------------------------------------------------------------
data ChildEvent a =
ChildDone ThreadId (Either SomeException [a]) -- A child is finished
| PassOnResult (Either SomeException [a]) -- Pass on the result of a child
ChildDone ThreadId (Maybe SomeException) -- A child is finished
| ChildResult (Either SomeException a) -- Pass on the result of a child
-- | Describes the context of a computation.
data Context = Context
@ -63,7 +61,12 @@ data Context = Context
, mfData :: M.Map TypeRep Any -- untyped, type coerced
-- ^ State data accessed with get or put operations
-- XXX use Either parentChannel accumResults
, accumResults :: Maybe (IORef [Any])
-- ^ Accumulated results, only the top level thread context accumulates.
-- Child threads just pass on results via the parent channels.
-- XXX we can pass this at the time of fork rather than keeping it here.
, parentChannel :: Maybe (TChan (ChildEvent Any))
-- ^ Our parent thread's channel to communicate to when we die
@ -104,8 +107,6 @@ data Context = Context
--
, threadCredit :: IORef Int
-- ^ How many more threads are allowed to be created?
, accumResults :: [Any]
-- ^ Accumulated results when running synchronously
} deriving Typeable
initContext
@ -114,8 +115,9 @@ initContext
-> IORef [ThreadId]
-> IORef Int
-> (b -> m a)
-> IORef [a]
-> Context
initContext x childChan pending credit finalizer =
initContext x childChan pending credit finalizer results =
Context { mailBox = Nothing
, currentm = unsafeCoerce x
, fstack = [unsafeCoerce finalizer]
@ -125,7 +127,7 @@ initContext x childChan pending credit finalizer =
, childChannel = unsafeCoerce childChan
, pendingThreads = pending
, threadCredit = credit
, accumResults = [] }
, accumResults = Just (unsafeCoerce results) }
------------------------------------------------------------------------------
-- Where is the computation running?
@ -199,14 +201,6 @@ resumeContext Context { currentm = m
composefStack [] _ = error "Bug: this should never be reached"
composefStack (f:ff) x = f x >>= composefStack ff
contextSaveResult :: Monad m => a -> StateT Context m ()
contextSaveResult r =
modify $ \ Context {accumResults = rs, ..} ->
Context {accumResults = unsafeCoerce r : rs, ..}
contextGetResult :: Context -> [a]
contextGetResult ctx = unsafeCoerce $ accumResults ctx
setContextMailBox :: Context -> a -> Context
setContextMailBox ctx mbdata = ctx { mailBox = Just $ unsafeCoerce mbdata }

View File

@ -69,74 +69,20 @@ import Strands.Context
-- | Continue execution of the closure that we were executing when we migrated
-- to a new thread.
resume :: MonadIO m => Context -> StateT Context m ()
resume ctx = do
runContext :: MonadIO m => Context -> StateT Context m ()
runContext ctx = do
-- XXX rename to buildContext or buildState?
let s = runAsyncT (resumeContext ctx)
-- The returned value is always 'Nothing', we just discard it
(_, c) <- lift $ runStateT s ctx
-- XXX can we pass the result directly to the root thread instead of
-- passing through all the parents? We can let the parent go away and
-- handle the ChildDone events as well in the root thread.
case parentChannel c of
Nothing -> modify $ \x ->
x { accumResults = accumResults c ++ accumResults x }
Just chan -> do
let r = accumResults c
when (length r /= 0) $
-- there is only one result in case of a non-root thread
-- XXX change the return type to 'a' instead of '[a]'
liftIO $ atomically $ writeTChan chan (PassOnResult (Right r))
_ <- lift $ runStateT s ctx
return ()
------------------------------------------------------------------------------
-- Thread Management (creation, reaping and killing)
------------------------------------------------------------------------------
-- XXX We are using unbounded channels so this will not block on writing to
-- pchan. We can use bounded channels to throttle the creation of threads based
-- on consumption rate.
processOneEvent :: MonadIO m
=> ChildEvent a
-> TChan (ChildEvent a)
-> [ThreadId]
-> Maybe SomeException
-> m ([ThreadId], Maybe SomeException)
processOneEvent ev pchan pending exc = do
e <- case exc of
Nothing ->
case ev of
ChildDone tid res -> do
dbg $ "processOneEvent ChildDone: " ++ show tid
handlePass res
PassOnResult res -> do
dbg $ "processOneEvent PassOnResult"
handlePass res
Just _ -> return exc
let p = case ev of
ChildDone tid _ -> delete tid pending
_ -> pending
return (p, e)
where
handlePass :: MonadIO m
=> Either SomeException [a] -> m (Maybe SomeException)
handlePass res =
case res of
Left e -> do
dbg $ "handlePass: caught exception"
liftIO $ mapM_ killThread pending
return (Just e)
Right [] -> return Nothing
Right _ -> do
liftIO $ atomically $ writeTChan pchan
(PassOnResult (unsafeCoerce res))
return Nothing
tryReclaimZombies :: (MonadIO m, MonadThrow m) => Context -> m ()
tryReclaimZombies ctx = do
let pchan = fromJust (parentChannel ctx)
let dest = getCtxResultDest ctx
cchan = childChannel ctx
pendingRef = pendingThreads ctx
@ -148,7 +94,7 @@ tryReclaimZombies ctx = do
case mev of
Nothing -> return ()
Just ev -> do
(p, e) <- processOneEvent ev pchan pending Nothing
(p, e) <- processOneEvent ev dest pending
liftIO $ writeIORef pendingRef p
maybe (return ()) throwM e
tryReclaimZombies ctx
@ -157,40 +103,16 @@ waitForOneEvent :: (MonadIO m, MonadThrow m) => Context -> m ()
waitForOneEvent ctx = do
-- XXX assert pending must have at least one element
-- assert that the tid is found in our list
let pchan = fromJust (parentChannel ctx)
let dest = getCtxResultDest ctx
cchan = childChannel ctx
pendingRef = pendingThreads ctx
ev <- liftIO $ atomically $ readTChan cchan
pending <- liftIO $ readIORef pendingRef
(p, e) <- processOneEvent ev pchan pending Nothing
(p, e) <- processOneEvent ev dest pending
liftIO $ writeIORef pendingRef p
maybe (return ()) throwM e
drainChildren :: MonadIO m
=> TChan (ChildEvent a)
-> TChan (ChildEvent a)
-> [ThreadId]
-> Maybe SomeException
-> m (Maybe SomeException)
drainChildren pchan cchan pending exc =
if pending == []
then return exc
else do
ev <- liftIO $ atomically $ readTChan cchan
(p, e) <- processOneEvent ev pchan pending exc
drainChildren pchan cchan p e
waitForChildren :: MonadIO m
=> Context -> Maybe SomeException -> m (Maybe SomeException)
waitForChildren ctx exc = do
let pendingRef = pendingThreads ctx
pchan = fromJust (parentChannel ctx)
pending <- liftIO $ readIORef pendingRef
e <- drainChildren pchan (childChannel ctx) pending exc
liftIO $ writeIORef pendingRef []
return e
-- | kill all the child threads associated with the continuation context
killChildren :: Context -> IO ()
killChildren ctx = do
@ -233,7 +155,8 @@ forkFinally1 :: (MonadIO m, MonadBaseControl IO m)
forkFinally1 ctx preExit =
EL.mask $ \restore ->
liftBaseWith $ \runInIO -> forkIO $ do
_ <- runInIO $ EL.try (restore (resume ctx)) >>= liftIO . preExit
_ <- runInIO $ EL.try (restore (runContext ctx))
>>= liftIO . preExit
return ()
-- | Run a given context in a new thread.
@ -263,27 +186,22 @@ forkContext context = do
{ parentChannel = Just (childChannel ctx)
, pendingThreads = pendingRef
, childChannel = chan
, accumResults = []
, accumResults = Nothing
}
beforeExit ctx res = do
tid <- myThreadId
exc <- case res of
Left e -> do
r <- case res of
Left e -> do
dbg $ "beforeExit: " ++ show tid ++ " caught exception"
liftIO $ killChildren ctx
return (Just e)
Right _ -> return Nothing
Right _ -> waitForChildren ctx
e <- waitForChildren ctx exc
-- We are guaranteed to have a parent because we have been explicitly
-- forked by some parent.
-- We are guaranteed to have a parent because we are forked.
let p = fromJust (parentChannel ctx)
signalQSemB (threadCredit ctx)
-- XXX change the return value type to Maybe SomeException
liftIO $ atomically $ writeTChan p
(ChildDone tid (maybe (Right []) Left e))
liftIO $ atomically $ writeTChan p (ChildDone tid r)
-- | Decide whether to resume the context in the same thread or a new thread
--
@ -318,7 +236,7 @@ resumeContextWith context synch action = do
let ctx = setContextMailBox context (action context)
can <- liftIO $ canFork context
case can && (not synch) of
False -> resume ctx -- run synchronously
False -> runContext ctx -- run synchronously
True -> forkContext ctx
-- | 'StreamData' represents a task in a task stream being generated.

View File

@ -1,5 +1,8 @@
import Strands
import Control.Monad.IO.Class (liftIO)
main = waitAsync $ do
liftIO $ putStrLn "hello"
main = do
xs <- waitAsync $ do
liftIO $ putStrLn "hello"
return 5
print xs