diff --git a/src/Asyncly/AsyncT.hs b/src/Asyncly/AsyncT.hs index fa872ee09..bb53b8d0a 100644 --- a/src/Asyncly/AsyncT.hs +++ b/src/Asyncly/AsyncT.hs @@ -26,10 +26,10 @@ where import Control.Applicative (Alternative (..)) import Control.Concurrent (ThreadId, forkIO, killThread, - myThreadId) + myThreadId, newQSem, QSem, + signalQSem, waitQSem) import Control.Concurrent.STM (TChan, atomically, newTChan, - readTChan, tryReadTChan, - writeTChan) + tryReadTChan, writeTChan) import Control.Exception (SomeException (..)) import qualified Control.Exception.Lifted as EL import Control.Monad (ap, liftM, MonadPlus(..), mzero) @@ -85,10 +85,10 @@ data ChildEvent a = ------------------------------------------------------------------------------ data Context a = - Context { childChannel :: TChan (ChildEvent a) - , pullSide :: Bool - , runningThreads :: IORef (Set ThreadId) - , doneThreads :: IORef (Set ThreadId) + Context { childChannel :: TChan (ChildEvent a) + , dispatchReq :: QSem + , runningThreads :: IORef (Set ThreadId) + , doneThreads :: IORef (Set ThreadId) } -- The 'Maybe (AsyncT m a)' is redundant as we can use 'stop' value for the @@ -182,7 +182,9 @@ doFork action exHandler = {-# NOINLINE push #-} push :: MonadIO m => Context a -> AsyncT m a -> m () -push context action = run (Just context) action +push context action = do + liftIO $ waitQSem (dispatchReq context) + run (Just context) action where @@ -203,24 +205,17 @@ push context action = run (Just context) action continue a ctx m = channelYield a >> run ctx m yielder a ctx r = maybe (done a) (\rx -> continue a ctx rx) r --- If an exception occurs we push it to the channel so that it can handled by --- the parent. 'Paused' exceptions are to be collected at the top level. --- XXX Paused exceptions should only bubble up to the runRecorder computation -{-# NOINLINE handleChildException #-} -handleChildException :: TChan (ChildEvent a) -> SomeException -> IO () -handleChildException pchan e = do - tid <- myThreadId - atomically $ writeTChan pchan (ChildStop tid (Just e)) - -{-# NOINLINE pushSideDispatch #-} -pushSideDispatch :: MonadAsync m - => Context a -> AsyncT m a -> AsyncT m a -> AsyncT m a -pushSideDispatch ctx m1 m2 = AsyncT $ \_ stp yld -> do - let chan = childChannel ctx - tid <- doFork (push ctx m1) (handleChildException chan) - liftIO $ atomically $ writeTChan chan (ChildCreate tid) - (runAsyncT m2) (Just ctx) stp yld - +-- Thread tracking has a significant performance overhead (~20% on empty +-- threads, it will be lower for heavy threads). It is needed for two reasons: +-- +-- 1) Killing threads on exceptions. Threads may not be allowed to go away by +-- themselves because they may run for significant times before going away or +-- worse they may be stuck in IO and never go away. +-- +-- 2) To know when all threads are done. This can be acheived by detecting a +-- BlockedIndefinitelyOnSTM exception too. But we will have to trigger a GC to +-- make sure that we detect it promptly. +-- -- This is a bit messy because ChildCreate and ChildDone events can arrive out -- of order in case of pushSideDispatch. Returns whether we are done draining -- threads. @@ -258,75 +253,49 @@ handleException e ctx tid = do -- We re-raise any exceptions received from the child threads, that way -- exceptions get propagated to the top level computation and can be handled -- there. -{-# NOINLINE pullDispatch #-} -pullDispatch :: (MonadIO m, MonadThrow m) - => Context a -> AsyncT m a -> Bool -> AsyncT m a -pullDispatch ctx m dispatch = AsyncT $ \_ stp yld -> do - if dispatch then do - (runAsyncT m) (Just ctx) stp yld - else do - res <- liftIO $ atomically $ tryReadTChan (childChannel ctx) - maybe (continue stp yld) (\ev -> handleEvent ev stp yld) res +{-# NOINLINE pullDrain #-} +pullDrain :: (MonadIO m, MonadThrow m) => Context a -> AsyncT m a +pullDrain ctx = AsyncT $ \_ stp yld -> do + res <- liftIO $ atomically $ tryReadTChan (childChannel ctx) + maybe (dispatch >> continue stp yld) + (\ev -> handleEvent ev stp yld) res where {-# INLINE continue #-} - continue stp yld = (runAsyncT (pullDispatch ctx m True)) (Just ctx) stp yld + continue stp yld = (runAsyncT (pullDrain ctx)) Nothing stp yld + + {-# INLINE dispatch #-} + dispatch = liftIO $ signalQSem (dispatchReq ctx) {-# INLINE handleEvent #-} - handleEvent ev stp yld = - case ev of - ChildYield a -> yld a Nothing (Just (pullDispatch ctx m False)) + handleEvent ev stp yld = do + let yielder a = yld a Nothing (Just (pullDrain ctx)) + case ev of + ChildYield a -> yielder a ChildDone tid a -> do - void $ delThread ctx tid - yld a Nothing (Just (pullDispatch ctx m True)) + dispatch + done <- delThread ctx tid + if done then (yld a Nothing Nothing) else (yielder a) ChildStop tid e -> do case e of - Nothing -> void (delThread ctx tid) >> continue stp yld + Nothing -> do + dispatch + done <- delThread ctx tid + if done then stp else continue stp yld Just x -> handleException x ctx tid - ChildCreate tid -> void (addThread ctx tid) >> continue stp yld + ChildCreate tid -> do + done <- addThread ctx tid + if done then stp else continue stp yld --- | run m1 in a new thread, pushing its results to a pull channel and then run --- m2 in the parent thread. Any exceptions are also pushed to the channel. -{-# INLINE pullSideDispatch #-} -pullSideDispatch :: MonadAsync m - => Context a -> AsyncT m a -> AsyncT m a -> AsyncT m a -pullSideDispatch ctx m1 m2 = AsyncT $ \_ stp yld -> do - let chan = childChannel ctx - tid <- doFork (push (ctx {pullSide = False}) m1) - (handleChildException chan) - liftIO $ modifyIORef (runningThreads ctx) $ (\s -> S.insert tid s) - (runAsyncT (pullDispatch ctx m2 False)) Nothing stp yld - -{-# NOINLINE pullDrain #-} -pullDrain :: (MonadIO m, MonadThrow m) => Context a -> AsyncT m a -pullDrain ctx = AsyncT $ \_ stp yld -> do - let yielder a = yld a Nothing (Just (pullDrain ctx)) - continue = (runAsyncT (pullDrain ctx)) Nothing stp yld - - res <- liftIO $ atomically $ readTChan (childChannel ctx) - case res of - ChildYield a -> yielder a - ChildDone tid a -> do - done <- delThread ctx tid - if done then (yld a Nothing Nothing) else (yielder a) - ChildStop tid e -> do - case e of - Nothing -> do - done <- delThread ctx tid - if done then stp else continue - Just x -> handleException x ctx tid - ChildCreate tid -> do - done <- addThread ctx tid - if done then stp else continue - -pullDrainStart :: (MonadIO m, MonadThrow m) => Context a -> AsyncT m a -pullDrainStart ctx = AsyncT $ \_ stp yld -> do - r <- liftIO $ readIORef (runningThreads ctx) - d <- liftIO $ readIORef (doneThreads ctx) - if (S.null r && S.null d) - then stp - else (runAsyncT (pullDrain ctx)) Nothing stp yld +-- If an exception occurs we push it to the channel so that it can handled by +-- the parent. 'Paused' exceptions are to be collected at the top level. +-- XXX Paused exceptions should only bubble up to the runRecorder computation +{-# NOINLINE handleChildException #-} +handleChildException :: TChan (ChildEvent a) -> SomeException -> IO () +handleChildException pchan e = do + tid <- myThreadId + atomically $ writeTChan pchan (ChildStop tid (Just e)) -- | Split the original computation in a pull-push pair. The original -- computation pulls from a Channel while m1 and m2 push to the channel. @@ -334,19 +303,30 @@ pullDrainStart ctx = AsyncT $ \_ stp yld -> do pullFork :: MonadAsync m => AsyncT m a -> AsyncT m a -> AsyncT m a pullFork m1 m2 = AsyncT $ \_ stp yld -> do ctx <- liftIO $ newContext - let m = pullDispatch ctx ((m1 <|> m2) <> pullDrainStart ctx) True - (runAsyncT m) Nothing stp yld + initialFork ctx m1 >> initialFork ctx m2 + (runAsyncT (pullDrain ctx)) Nothing stp yld where + -- This function is different than "fork" because we have to directly + -- insert the threadIds here and cannot use the channel to send ChildCreate + -- unlike on the push side. If we do that, the first thread's done message + -- may arrive even before the second thread is forked, in that case + -- pullDrain will falsely detect that all threads are over. + initialFork ctx m = do + let chan = childChannel ctx + tid <- doFork (push ctx m) (handleChildException chan) + liftIO $ modifyIORef (runningThreads ctx) $ (\s -> S.insert tid s) + newContext = do channel <- atomically newTChan running <- liftIO $ newIORef S.empty - done <- liftIO $ newIORef S.empty - return $ Context { childChannel = channel - , pullSide = True + done <- liftIO $ newIORef S.empty + count <- liftIO $ newQSem 1 + return $ Context { childChannel = channel + , dispatchReq = count , runningThreads = running - , doneThreads = done + , doneThreads = done } -- Concurrency rate control. Our objective is to create more threads on @@ -379,7 +359,14 @@ pullFork m1 m2 = AsyncT $ \_ stp yld -> do -- XXX to rate control left folded structrues we will have to return the -- residual work back to the dispatcher. It will also consume a lot of -- memory due to queueing of all the work before execution starts. --- + +{-# INLINE fork #-} +fork :: MonadAsync m => Context a -> AsyncT m a -> m () +fork ctx m = do + let chan = childChannel ctx + tid <- doFork (push ctx m) (handleChildException chan) + liftIO $ atomically $ writeTChan chan (ChildCreate tid) + instance MonadAsync m => Alternative (AsyncT m) where empty = mempty @@ -389,12 +376,7 @@ instance MonadAsync m => Alternative (AsyncT m) where m1 <|> m2 = AsyncT $ \ctx stp yld -> do case ctx of Nothing -> (runAsyncT (pullFork m1 m2)) ctx stp yld - Just c -> do - if pullSide c then - (runAsyncT (pullSideDispatch c m1 m2)) ctx stp yld - else - -- for left associated compositions - (runAsyncT (pushSideDispatch c m1 m2)) ctx stp yld + Just c -> fork c m2 >> (runAsyncT m1) ctx stp yld instance MonadAsync m => MonadPlus (AsyncT m) where mzero = empty