mirror of
https://github.com/awkward-squad/ki.git
synced 2024-10-03 22:57:51 +03:00
refactoring, cleanup
This commit is contained in:
parent
59312dc5c1
commit
f5e9a45fa8
@ -24,13 +24,15 @@ where
|
||||
import Control.Exception
|
||||
import Control.Monad (join)
|
||||
import Data.Coerce (coerce)
|
||||
import Data.Maybe (isJust)
|
||||
import GHC.Base (maskAsyncExceptions#, maskUninterruptible#)
|
||||
import GHC.Conc (STM, ThreadId (ThreadId), catchSTM)
|
||||
import GHC.Exts (Int (I#), fork#, forkOn#)
|
||||
import GHC.IO (IO (IO))
|
||||
import Prelude
|
||||
|
||||
-- A little promise that this IO action cannot throw an exception.
|
||||
-- A little promise that this IO action cannot throw an exception (*including* async exceptions, which you normally
|
||||
-- think of as being able to strike at any time).
|
||||
--
|
||||
-- Yeah it's verbose, and maybe not that necessary, but the code that bothers to use it really does require
|
||||
-- un-exceptiony IO actions for correctness, so here we are.
|
||||
@ -42,13 +44,17 @@ data IOResult a
|
||||
= Failure !SomeException -- sync or async exception
|
||||
| Success a
|
||||
|
||||
-- Try an action, catching any exception it throws.
|
||||
--
|
||||
-- The caller is responsible for ensuring that async exceptions are masked (at whatever masking level is appropriate),
|
||||
-- as (again) `UnexceptionalIO` implies async exceptions won't be thrown either.
|
||||
unexceptionalTry :: forall a. IO a -> UnexceptionalIO (IOResult a)
|
||||
unexceptionalTry action =
|
||||
UnexceptionalIO do
|
||||
(Success <$> action) `catch` \exception ->
|
||||
pure (Failure exception)
|
||||
|
||||
-- Like try, but with continuations. Also, catches all exceptions, because that's the only flavor we need.
|
||||
-- Like try, but with continuations.
|
||||
unexceptionalTryEither ::
|
||||
forall a b.
|
||||
(SomeException -> UnexceptionalIO b) ->
|
||||
@ -63,20 +69,18 @@ unexceptionalTryEither onFailure onSuccess action =
|
||||
(pure . coerce @_ @(SomeException -> IO b) onFailure)
|
||||
|
||||
isAsyncException :: SomeException -> Bool
|
||||
isAsyncException exception =
|
||||
case fromException @SomeAsyncException exception of
|
||||
Nothing -> False
|
||||
Just _ -> True
|
||||
isAsyncException =
|
||||
isJust . fromException @SomeAsyncException
|
||||
|
||||
-- | Call an action with asynchronous exceptions interruptibly masked.
|
||||
interruptiblyMasked :: IO a -> IO a
|
||||
interruptiblyMasked (IO io) =
|
||||
IO (maskAsyncExceptions# io)
|
||||
interruptiblyMasked :: forall a. IO a -> IO a
|
||||
interruptiblyMasked =
|
||||
coerce (maskAsyncExceptions# @a)
|
||||
|
||||
-- | Call an action with asynchronous exceptions uninterruptibly masked.
|
||||
uninterruptiblyMasked :: IO a -> IO a
|
||||
uninterruptiblyMasked (IO io) =
|
||||
IO (maskUninterruptible# io)
|
||||
uninterruptiblyMasked :: forall a. IO a -> IO a
|
||||
uninterruptiblyMasked =
|
||||
coerce (maskUninterruptible# @a)
|
||||
|
||||
-- Like try, but with continuations
|
||||
tryEitherSTM :: (Exception e) => (e -> STM b) -> (a -> STM b) -> STM a -> STM b
|
||||
|
@ -223,63 +223,74 @@ allocateScope = do
|
||||
|
||||
-- Spawn a thread in a scope, providing it its child id and a function that sets the masking state to the requested
|
||||
-- masking state. The given action is called with async exceptions interruptibly masked.
|
||||
spawn :: Scope -> ThreadOptions -> (Tid -> (forall x. IO x -> IO x) -> UnexceptionalIO ()) -> IO ThreadId
|
||||
spawn
|
||||
Scope {childrenVar, nextChildIdSupply, statusVar}
|
||||
ThreadOptions {affinity, allocationLimit, label, maskingState = requestedChildMaskingState}
|
||||
action = do
|
||||
-- Interruptible mask is enough so long as none of the STM operations below block.
|
||||
--
|
||||
-- Unconditionally set masking state to MaskedInterruptible, even though we might already be at MaskedInterruptible
|
||||
-- or MaskedUninterruptible, to avoid a branch on parentMaskingState.
|
||||
interruptiblyMasked do
|
||||
-- Record the thread as being about to start. Not allowed to retry.
|
||||
nonblockingAtomically do
|
||||
n <- nonblockingReadTVar statusVar
|
||||
assert (n >= -2) do
|
||||
case n of
|
||||
Open -> nonblockingWriteTVar' statusVar (n + 1)
|
||||
Closing -> nonblockingThrowSTM ScopeClosing
|
||||
Closed -> nonblockingThrowSTM (ErrorCall "ki: scope closed")
|
||||
spawn :: Scope -> ThreadOptions -> (Tid -> (forall x. IO x -> IO x) -> UnexceptionalIO ()) -> IO ChildIds
|
||||
spawn scope@Scope {childrenVar, statusVar} options action = do
|
||||
-- Interruptible mask is enough so long as none of the STM operations below block.
|
||||
--
|
||||
-- Unconditionally set masking state to MaskedInterruptible, even though we might already be at MaskedInterruptible
|
||||
-- or MaskedUninterruptible, to avoid a branch on parentMaskingState.
|
||||
interruptiblyMasked do
|
||||
-- Record the thread as being about to start. Not allowed to retry.
|
||||
nonblockingAtomically do
|
||||
status <- nonblockingReadTVar statusVar
|
||||
assert (status >= -2) do
|
||||
case status of
|
||||
Open -> nonblockingWriteTVar' statusVar (status + 1)
|
||||
Closing -> nonblockingThrowSTM ScopeClosing
|
||||
Closed -> nonblockingThrowSTM (ErrorCall "ki: scope closed")
|
||||
|
||||
childId <- IntSupply.next nextChildIdSupply
|
||||
childIds <- spawnChild scope options action
|
||||
|
||||
childThreadId <-
|
||||
forkWithAffinity affinity do
|
||||
when (not (null label)) do
|
||||
childThreadId <- myThreadId
|
||||
labelThread childThreadId label
|
||||
-- Record the child as having started. Not allowed to retry.
|
||||
nonblockingAtomically do
|
||||
starting <- nonblockingReadTVar statusVar
|
||||
assert (starting >= 1) do
|
||||
nonblockingWriteTVar' statusVar (starting - 1)
|
||||
recordChild childrenVar childIds
|
||||
|
||||
for_ allocationLimit \bytes -> do
|
||||
setAllocationCounter (byteCountToInt64 bytes)
|
||||
enableAllocationLimit
|
||||
pure childIds
|
||||
|
||||
let -- Action that sets the masking state from the current (MaskedInterruptible) to the requested one.
|
||||
atRequestedMaskingState :: IO a -> IO a
|
||||
atRequestedMaskingState =
|
||||
case requestedChildMaskingState of
|
||||
Unmasked -> unsafeUnmask
|
||||
MaskedInterruptible -> id
|
||||
MaskedUninterruptible -> uninterruptiblyMasked
|
||||
data ChildIds
|
||||
= ChildIds
|
||||
{-# UNPACK #-} !Tid
|
||||
{-# UNPACK #-} !ThreadId
|
||||
|
||||
runUnexceptionalIO (action childId atRequestedMaskingState)
|
||||
spawnChild :: Scope -> ThreadOptions -> (Tid -> (forall x. IO x -> IO x) -> UnexceptionalIO ()) -> IO ChildIds
|
||||
spawnChild scope options action = do
|
||||
childId <- IntSupply.next nextChildIdSupply
|
||||
childThreadId <-
|
||||
forkWithAffinity affinity do
|
||||
when (not (null label)) do
|
||||
childThreadId <- myThreadId
|
||||
labelThread childThreadId label
|
||||
|
||||
nonblockingAtomically (unrecordChild childrenVar childId)
|
||||
for_ allocationLimit \bytes -> do
|
||||
setAllocationCounter (byteCountToInt64 bytes)
|
||||
enableAllocationLimit
|
||||
|
||||
-- Record the child as having started. Not allowed to retry.
|
||||
nonblockingAtomically do
|
||||
n <- nonblockingReadTVar statusVar
|
||||
nonblockingWriteTVar' statusVar (n - 1)
|
||||
recordChild childrenVar childId childThreadId
|
||||
let -- Action that sets the masking state from the current (MaskedInterruptible) to the requested one.
|
||||
atRequestedMaskingState :: IO a -> IO a
|
||||
atRequestedMaskingState =
|
||||
case requestedChildMaskingState of
|
||||
Unmasked -> unsafeUnmask
|
||||
MaskedInterruptible -> id
|
||||
MaskedUninterruptible -> uninterruptiblyMasked
|
||||
|
||||
pure childThreadId
|
||||
runUnexceptionalIO (action childId atRequestedMaskingState)
|
||||
|
||||
nonblockingAtomically (unrecordChild childrenVar childId)
|
||||
pure (ChildIds childId childThreadId)
|
||||
where
|
||||
Scope {childrenVar, nextChildIdSupply} = scope
|
||||
ThreadOptions {affinity, allocationLimit, label, maskingState = requestedChildMaskingState} = options
|
||||
{-# INLINE spawnChild #-}
|
||||
|
||||
-- Record our child by either:
|
||||
--
|
||||
-- * Flipping `Nothing` to `Just childThreadId` (common case: we record child before it unrecords itself)
|
||||
-- * Flipping `Just _` to `Nothing` (uncommon case: we observe that a child already unrecorded itself)
|
||||
recordChild :: TVar (IntMap ThreadId) -> Tid -> ThreadId -> NonblockingSTM ()
|
||||
recordChild childrenVar childId childThreadId = do
|
||||
recordChild :: TVar (IntMap ThreadId) -> ChildIds -> NonblockingSTM ()
|
||||
recordChild childrenVar (ChildIds childId childThreadId) = do
|
||||
children <- nonblockingReadTVar childrenVar
|
||||
nonblockingWriteTVar' childrenVar (IntMap.Lazy.alter (maybe (Just childThreadId) (const Nothing)) childId children)
|
||||
|
||||
@ -298,7 +309,7 @@ awaitAll Scope {childrenVar, statusVar} = do
|
||||
children <- readTVar childrenVar
|
||||
guard (IntMap.Lazy.null children)
|
||||
status <- readTVar statusVar
|
||||
case status of
|
||||
assert (status >= -2) case status of
|
||||
Open -> guard (status == 0)
|
||||
Closing -> retry -- block until closed
|
||||
Closed -> pure ()
|
||||
@ -321,14 +332,12 @@ forkWith :: Scope -> ThreadOptions -> IO a -> IO (Thread a)
|
||||
forkWith scope opts action = do
|
||||
resultVar <- newTVarIO NoResultYet
|
||||
let done result = UnexceptionalIO (atomically (writeTVar resultVar result))
|
||||
ident <-
|
||||
ChildIds _ childThreadId <-
|
||||
spawn scope opts \childId masking -> do
|
||||
result <- unexceptionalTry (masking action)
|
||||
case result of
|
||||
unexceptionalTry (masking action) >>= \case
|
||||
Failure exception -> do
|
||||
when
|
||||
(not (isScopeClosingException exception))
|
||||
(propagateException scope childId exception)
|
||||
when (not (isScopeClosingException exception)) do
|
||||
propagateException scope childId exception
|
||||
-- even put async exceptions that we propagated. this isn't totally ideal because a caller awaiting this
|
||||
-- thread would not be able to distinguish between async exceptions delivered to this thread, or itself
|
||||
done (BadResult exception)
|
||||
@ -338,7 +347,7 @@ forkWith scope opts action = do
|
||||
NoResultYet -> retry
|
||||
BadResult exception -> throwSTM exception
|
||||
GoodResult value -> pure value
|
||||
pure (makeThread ident doAwait)
|
||||
pure (makeThread childThreadId doAwait)
|
||||
|
||||
-- | Variant of 'Ki.forkWith' for threads that never return.
|
||||
forkWith_ :: Scope -> ThreadOptions -> IO Void -> IO ()
|
||||
@ -369,7 +378,7 @@ forkTryWith :: forall e a. (Exception e) => Scope -> ThreadOptions -> IO a -> IO
|
||||
forkTryWith scope opts action = do
|
||||
resultVar <- newTVarIO NoResultYet
|
||||
let done result = UnexceptionalIO (atomically (writeTVar resultVar result))
|
||||
childThreadId <-
|
||||
ChildIds _ childThreadId <-
|
||||
spawn scope opts \childId masking -> do
|
||||
result <- unexceptionalTry (masking action)
|
||||
case result of
|
||||
@ -427,7 +436,7 @@ forkTryWith scope opts action = do
|
||||
propagateException :: Scope -> Tid -> SomeException -> UnexceptionalIO ()
|
||||
propagateException Scope {childExceptionVar, parentThreadId, statusVar} childId exception =
|
||||
UnexceptionalIO (readTVarIO statusVar) >>= \case
|
||||
Closing -> tryPutChildExceptionVar -- (A) / (B)
|
||||
Closing -> tryPutChildExceptionVar -- (A) or (B), we don't care which
|
||||
status -> assert (status >= 0) loop -- we know status is Open here
|
||||
where
|
||||
loop :: UnexceptionalIO ()
|
||||
|
Loading…
Reference in New Issue
Block a user