mirror of
https://github.com/awkward-squad/ki.git
synced 2024-11-20 17:52:27 +03:00
add NonblockingSTM type
This commit is contained in:
parent
a5da3ba990
commit
30d26f47ac
@ -90,6 +90,7 @@ library
|
||||
other-modules:
|
||||
Ki.Internal.ByteCount
|
||||
Ki.Internal.IO
|
||||
Ki.Internal.NonblockingSTM
|
||||
Ki.Internal.Scope
|
||||
Ki.Internal.Thread
|
||||
|
||||
|
36
ki/src/Ki/Internal/NonblockingSTM.hs
Normal file
36
ki/src/Ki/Internal/NonblockingSTM.hs
Normal file
@ -0,0 +1,36 @@
|
||||
-- | STM minus retry. These STM actions are guaranteed not to block, and thus guaranteed not to be interrupted by an
|
||||
-- async exception.
|
||||
module Ki.Internal.NonblockingSTM
|
||||
( NonblockingSTM,
|
||||
nonblockingAtomically,
|
||||
nonblockingThrowSTM,
|
||||
|
||||
-- * TVar
|
||||
nonblockingReadTVar,
|
||||
nonblockingWriteTVar',
|
||||
)
|
||||
where
|
||||
|
||||
import Control.Exception (Exception)
|
||||
import Data.Coerce (coerce)
|
||||
import GHC.Conc (STM, TVar, atomically, readTVar, throwSTM, writeTVar)
|
||||
|
||||
newtype NonblockingSTM a
|
||||
= NonblockingSTM (STM a)
|
||||
deriving newtype (Applicative, Functor, Monad)
|
||||
|
||||
nonblockingAtomically :: forall a. NonblockingSTM a -> IO a
|
||||
nonblockingAtomically =
|
||||
coerce @(STM a -> IO a) atomically
|
||||
|
||||
nonblockingThrowSTM :: forall e x. (Exception e) => e -> NonblockingSTM x
|
||||
nonblockingThrowSTM =
|
||||
coerce @(e -> STM x) throwSTM
|
||||
|
||||
nonblockingReadTVar :: forall a. TVar a -> NonblockingSTM a
|
||||
nonblockingReadTVar =
|
||||
coerce @(TVar a -> STM a) readTVar
|
||||
|
||||
nonblockingWriteTVar' :: forall a. TVar a -> a -> NonblockingSTM ()
|
||||
nonblockingWriteTVar' var !x =
|
||||
NonblockingSTM (writeTVar var x)
|
@ -47,9 +47,9 @@ import GHC.Conc
|
||||
)
|
||||
import GHC.Conc.Sync (readTVarIO)
|
||||
import GHC.IO (unsafeUnmask)
|
||||
import Ki.Internal.ByteCount
|
||||
import IntSupply (IntSupply)
|
||||
import qualified IntSupply
|
||||
import Ki.Internal.ByteCount
|
||||
import Ki.Internal.IO
|
||||
( IOResult (..),
|
||||
UnexceptionalIO (..),
|
||||
@ -59,6 +59,7 @@ import Ki.Internal.IO
|
||||
unexceptionalTryEither,
|
||||
uninterruptiblyMasked,
|
||||
)
|
||||
import Ki.Internal.NonblockingSTM
|
||||
import Ki.Internal.Thread
|
||||
|
||||
-- | A scope.
|
||||
@ -229,13 +230,13 @@ spawn
|
||||
-- or MaskedUninterruptible, to avoid a branch on parentMaskingState.
|
||||
interruptiblyMasked do
|
||||
-- Record the thread as being about to start. Not allowed to retry.
|
||||
atomically do
|
||||
n <- readTVar statusVar
|
||||
nonblockingAtomically do
|
||||
n <- nonblockingReadTVar statusVar
|
||||
assert (n >= -2) do
|
||||
case n of
|
||||
Open -> writeTVar statusVar $! n + 1
|
||||
Closing -> throwSTM ScopeClosing
|
||||
Closed -> throwSTM (ErrorCall "ki: scope closed")
|
||||
Open -> nonblockingWriteTVar' statusVar (n + 1)
|
||||
Closing -> nonblockingThrowSTM ScopeClosing
|
||||
Closed -> nonblockingThrowSTM (ErrorCall "ki: scope closed")
|
||||
|
||||
childId <- IntSupply.next nextChildIdSupply
|
||||
|
||||
@ -245,11 +246,9 @@ spawn
|
||||
childThreadId <- myThreadId
|
||||
labelThread childThreadId label
|
||||
|
||||
case allocationLimit of
|
||||
Nothing -> pure ()
|
||||
Just bytes -> do
|
||||
setAllocationCounter (byteCountToInt64 bytes)
|
||||
enableAllocationLimit
|
||||
for_ allocationLimit \bytes -> do
|
||||
setAllocationCounter (byteCountToInt64 bytes)
|
||||
enableAllocationLimit
|
||||
|
||||
let -- Action that sets the masking state from the current (MaskedInterruptible) to the requested one.
|
||||
atRequestedMaskingState :: IO a -> IO a
|
||||
@ -261,12 +260,12 @@ spawn
|
||||
|
||||
runUnexceptionalIO (action childId atRequestedMaskingState)
|
||||
|
||||
atomically (unrecordChild childrenVar childId)
|
||||
nonblockingAtomically (unrecordChild childrenVar childId)
|
||||
|
||||
-- Record the child as having started. Not allowed to retry.
|
||||
atomically do
|
||||
n <- readTVar statusVar
|
||||
writeTVar statusVar $! n - 1
|
||||
nonblockingAtomically do
|
||||
n <- nonblockingReadTVar statusVar
|
||||
nonblockingWriteTVar' statusVar (n - 1)
|
||||
recordChild childrenVar childId childThreadId
|
||||
|
||||
pure childThreadId
|
||||
@ -275,23 +274,19 @@ spawn
|
||||
--
|
||||
-- * Flipping `Nothing` to `Just childThreadId` (common case: we record child before it unrecords itself)
|
||||
-- * Flipping `Just _` to `Nothing` (uncommon case: we observe that a child already unrecorded itself)
|
||||
--
|
||||
-- Never retries.
|
||||
recordChild :: TVar (IntMap ThreadId) -> Tid -> ThreadId -> STM ()
|
||||
recordChild :: TVar (IntMap ThreadId) -> Tid -> ThreadId -> NonblockingSTM ()
|
||||
recordChild childrenVar childId childThreadId = do
|
||||
children <- readTVar childrenVar
|
||||
writeTVar childrenVar $! IntMap.Lazy.alter (maybe (Just childThreadId) (const Nothing)) childId children
|
||||
children <- nonblockingReadTVar childrenVar
|
||||
nonblockingWriteTVar' childrenVar (IntMap.Lazy.alter (maybe (Just childThreadId) (const Nothing)) childId children)
|
||||
|
||||
-- Unrecord a child (ourselves) by either:
|
||||
--
|
||||
-- * Flipping `Just childThreadId` to `Nothing` (common case: parent recorded us first)
|
||||
-- * Flipping `Nothing` to `Just undefined` (uncommon case: we terminate and unrecord before parent can record us).
|
||||
--
|
||||
-- Never retries.
|
||||
unrecordChild :: TVar (IntMap ThreadId) -> Tid -> STM ()
|
||||
unrecordChild :: TVar (IntMap ThreadId) -> Tid -> NonblockingSTM ()
|
||||
unrecordChild childrenVar childId = do
|
||||
children <- readTVar childrenVar
|
||||
writeTVar childrenVar $! IntMap.Lazy.alter (maybe (Just undefined) (const Nothing)) childId children
|
||||
children <- nonblockingReadTVar childrenVar
|
||||
nonblockingWriteTVar' childrenVar (IntMap.Lazy.alter (maybe (Just undefined) (const Nothing)) childId children)
|
||||
|
||||
-- | Wait until all threads created within a scope terminate.
|
||||
awaitAll :: Scope -> STM ()
|
||||
|
Loading…
Reference in New Issue
Block a user