refactoring, cleanup

This commit is contained in:
Mitchell Rosen 2023-11-27 22:33:30 -05:00
parent 59312dc5c1
commit f5e9a45fa8
2 changed files with 79 additions and 66 deletions

View File

@ -24,13 +24,15 @@ where
import Control.Exception
import Control.Monad (join)
import Data.Coerce (coerce)
import Data.Maybe (isJust)
import GHC.Base (maskAsyncExceptions#, maskUninterruptible#)
import GHC.Conc (STM, ThreadId (ThreadId), catchSTM)
import GHC.Exts (Int (I#), fork#, forkOn#)
import GHC.IO (IO (IO))
import Prelude
-- A little promise that this IO action cannot throw an exception.
-- A little promise that this IO action cannot throw an exception (*including* async exceptions, which you normally
-- think of as being able to strike at any time).
--
-- Yeah it's verbose, and maybe not that necessary, but the code that bothers to use it really does require
-- un-exceptiony IO actions for correctness, so here we are.
@ -42,13 +44,17 @@ data IOResult a
= Failure !SomeException -- sync or async exception
| Success a
-- Try an action, catching any exception it throws.
--
-- The caller is responsible for ensuring that async exceptions are masked (at whatever masking level is appropriate),
-- as (again) `UnexceptionalIO` implies async exceptions won't be thrown either.
unexceptionalTry :: forall a. IO a -> UnexceptionalIO (IOResult a)
unexceptionalTry action =
UnexceptionalIO do
(Success <$> action) `catch` \exception ->
pure (Failure exception)
-- Like try, but with continuations. Also, catches all exceptions, because that's the only flavor we need.
-- Like try, but with continuations.
unexceptionalTryEither ::
forall a b.
(SomeException -> UnexceptionalIO b) ->
@ -63,20 +69,18 @@ unexceptionalTryEither onFailure onSuccess action =
(pure . coerce @_ @(SomeException -> IO b) onFailure)
isAsyncException :: SomeException -> Bool
isAsyncException exception =
case fromException @SomeAsyncException exception of
Nothing -> False
Just _ -> True
isAsyncException =
isJust . fromException @SomeAsyncException
-- | Call an action with asynchronous exceptions interruptibly masked.
interruptiblyMasked :: IO a -> IO a
interruptiblyMasked (IO io) =
IO (maskAsyncExceptions# io)
interruptiblyMasked :: forall a. IO a -> IO a
interruptiblyMasked =
coerce (maskAsyncExceptions# @a)
-- | Call an action with asynchronous exceptions uninterruptibly masked.
uninterruptiblyMasked :: IO a -> IO a
uninterruptiblyMasked (IO io) =
IO (maskUninterruptible# io)
uninterruptiblyMasked :: forall a. IO a -> IO a
uninterruptiblyMasked =
coerce (maskUninterruptible# @a)
-- Like try, but with continuations
tryEitherSTM :: (Exception e) => (e -> STM b) -> (a -> STM b) -> STM a -> STM b

View File

@ -223,11 +223,8 @@ allocateScope = do
-- Spawn a thread in a scope, providing it its child id and a function that sets the masking state to the requested
-- masking state. The given action is called with async exceptions interruptibly masked.
spawn :: Scope -> ThreadOptions -> (Tid -> (forall x. IO x -> IO x) -> UnexceptionalIO ()) -> IO ThreadId
spawn
Scope {childrenVar, nextChildIdSupply, statusVar}
ThreadOptions {affinity, allocationLimit, label, maskingState = requestedChildMaskingState}
action = do
spawn :: Scope -> ThreadOptions -> (Tid -> (forall x. IO x -> IO x) -> UnexceptionalIO ()) -> IO ChildIds
spawn scope@Scope {childrenVar, statusVar} options action = do
-- Interruptible mask is enough so long as none of the STM operations below block.
--
-- Unconditionally set masking state to MaskedInterruptible, even though we might already be at MaskedInterruptible
@ -235,15 +232,32 @@ spawn
interruptiblyMasked do
-- Record the thread as being about to start. Not allowed to retry.
nonblockingAtomically do
n <- nonblockingReadTVar statusVar
assert (n >= -2) do
case n of
Open -> nonblockingWriteTVar' statusVar (n + 1)
status <- nonblockingReadTVar statusVar
assert (status >= -2) do
case status of
Open -> nonblockingWriteTVar' statusVar (status + 1)
Closing -> nonblockingThrowSTM ScopeClosing
Closed -> nonblockingThrowSTM (ErrorCall "ki: scope closed")
childId <- IntSupply.next nextChildIdSupply
childIds <- spawnChild scope options action
-- Record the child as having started. Not allowed to retry.
nonblockingAtomically do
starting <- nonblockingReadTVar statusVar
assert (starting >= 1) do
nonblockingWriteTVar' statusVar (starting - 1)
recordChild childrenVar childIds
pure childIds
data ChildIds
= ChildIds
{-# UNPACK #-} !Tid
{-# UNPACK #-} !ThreadId
spawnChild :: Scope -> ThreadOptions -> (Tid -> (forall x. IO x -> IO x) -> UnexceptionalIO ()) -> IO ChildIds
spawnChild scope options action = do
childId <- IntSupply.next nextChildIdSupply
childThreadId <-
forkWithAffinity affinity do
when (not (null label)) do
@ -265,21 +279,18 @@ spawn
runUnexceptionalIO (action childId atRequestedMaskingState)
nonblockingAtomically (unrecordChild childrenVar childId)
-- Record the child as having started. Not allowed to retry.
nonblockingAtomically do
n <- nonblockingReadTVar statusVar
nonblockingWriteTVar' statusVar (n - 1)
recordChild childrenVar childId childThreadId
pure childThreadId
pure (ChildIds childId childThreadId)
where
Scope {childrenVar, nextChildIdSupply} = scope
ThreadOptions {affinity, allocationLimit, label, maskingState = requestedChildMaskingState} = options
{-# INLINE spawnChild #-}
-- Record our child by either:
--
-- * Flipping `Nothing` to `Just childThreadId` (common case: we record child before it unrecords itself)
-- * Flipping `Just _` to `Nothing` (uncommon case: we observe that a child already unrecorded itself)
recordChild :: TVar (IntMap ThreadId) -> Tid -> ThreadId -> NonblockingSTM ()
recordChild childrenVar childId childThreadId = do
recordChild :: TVar (IntMap ThreadId) -> ChildIds -> NonblockingSTM ()
recordChild childrenVar (ChildIds childId childThreadId) = do
children <- nonblockingReadTVar childrenVar
nonblockingWriteTVar' childrenVar (IntMap.Lazy.alter (maybe (Just childThreadId) (const Nothing)) childId children)
@ -298,7 +309,7 @@ awaitAll Scope {childrenVar, statusVar} = do
children <- readTVar childrenVar
guard (IntMap.Lazy.null children)
status <- readTVar statusVar
case status of
assert (status >= -2) case status of
Open -> guard (status == 0)
Closing -> retry -- block until closed
Closed -> pure ()
@ -321,14 +332,12 @@ forkWith :: Scope -> ThreadOptions -> IO a -> IO (Thread a)
forkWith scope opts action = do
resultVar <- newTVarIO NoResultYet
let done result = UnexceptionalIO (atomically (writeTVar resultVar result))
ident <-
ChildIds _ childThreadId <-
spawn scope opts \childId masking -> do
result <- unexceptionalTry (masking action)
case result of
unexceptionalTry (masking action) >>= \case
Failure exception -> do
when
(not (isScopeClosingException exception))
(propagateException scope childId exception)
when (not (isScopeClosingException exception)) do
propagateException scope childId exception
-- even put async exceptions that we propagated. this isn't totally ideal because a caller awaiting this
-- thread would not be able to distinguish between async exceptions delivered to this thread, or itself
done (BadResult exception)
@ -338,7 +347,7 @@ forkWith scope opts action = do
NoResultYet -> retry
BadResult exception -> throwSTM exception
GoodResult value -> pure value
pure (makeThread ident doAwait)
pure (makeThread childThreadId doAwait)
-- | Variant of 'Ki.forkWith' for threads that never return.
forkWith_ :: Scope -> ThreadOptions -> IO Void -> IO ()
@ -369,7 +378,7 @@ forkTryWith :: forall e a. (Exception e) => Scope -> ThreadOptions -> IO a -> IO
forkTryWith scope opts action = do
resultVar <- newTVarIO NoResultYet
let done result = UnexceptionalIO (atomically (writeTVar resultVar result))
childThreadId <-
ChildIds _ childThreadId <-
spawn scope opts \childId masking -> do
result <- unexceptionalTry (masking action)
case result of
@ -427,7 +436,7 @@ forkTryWith scope opts action = do
propagateException :: Scope -> Tid -> SomeException -> UnexceptionalIO ()
propagateException Scope {childExceptionVar, parentThreadId, statusVar} childId exception =
UnexceptionalIO (readTVarIO statusVar) >>= \case
Closing -> tryPutChildExceptionVar -- (A) / (B)
Closing -> tryPutChildExceptionVar -- (A) or (B), we don't care which
status -> assert (status >= 0) loop -- we know status is Open here
where
loop :: UnexceptionalIO ()