add NonblockingSTM type

This commit is contained in:
Mitchell Rosen 2023-11-27 21:25:44 -05:00
parent a5da3ba990
commit 30d26f47ac
3 changed files with 57 additions and 25 deletions

View File

@ -90,6 +90,7 @@ library
other-modules:
Ki.Internal.ByteCount
Ki.Internal.IO
Ki.Internal.NonblockingSTM
Ki.Internal.Scope
Ki.Internal.Thread

View File

@ -0,0 +1,36 @@
-- | STM minus retry. These STM actions are guaranteed not to block, and thus guaranteed not to be interrupted by an
-- async exception.
module Ki.Internal.NonblockingSTM
( NonblockingSTM,
nonblockingAtomically,
nonblockingThrowSTM,
-- * TVar
nonblockingReadTVar,
nonblockingWriteTVar',
)
where
import Control.Exception (Exception)
import Data.Coerce (coerce)
import GHC.Conc (STM, TVar, atomically, readTVar, throwSTM, writeTVar)
newtype NonblockingSTM a
= NonblockingSTM (STM a)
deriving newtype (Applicative, Functor, Monad)
nonblockingAtomically :: forall a. NonblockingSTM a -> IO a
nonblockingAtomically =
coerce @(STM a -> IO a) atomically
nonblockingThrowSTM :: forall e x. (Exception e) => e -> NonblockingSTM x
nonblockingThrowSTM =
coerce @(e -> STM x) throwSTM
nonblockingReadTVar :: forall a. TVar a -> NonblockingSTM a
nonblockingReadTVar =
coerce @(TVar a -> STM a) readTVar
nonblockingWriteTVar' :: forall a. TVar a -> a -> NonblockingSTM ()
nonblockingWriteTVar' var !x =
NonblockingSTM (writeTVar var x)

View File

@ -47,9 +47,9 @@ import GHC.Conc
)
import GHC.Conc.Sync (readTVarIO)
import GHC.IO (unsafeUnmask)
import Ki.Internal.ByteCount
import IntSupply (IntSupply)
import qualified IntSupply
import Ki.Internal.ByteCount
import Ki.Internal.IO
( IOResult (..),
UnexceptionalIO (..),
@ -59,6 +59,7 @@ import Ki.Internal.IO
unexceptionalTryEither,
uninterruptiblyMasked,
)
import Ki.Internal.NonblockingSTM
import Ki.Internal.Thread
-- | A scope.
@ -229,13 +230,13 @@ spawn
-- or MaskedUninterruptible, to avoid a branch on parentMaskingState.
interruptiblyMasked do
-- Record the thread as being about to start. Not allowed to retry.
atomically do
n <- readTVar statusVar
nonblockingAtomically do
n <- nonblockingReadTVar statusVar
assert (n >= -2) do
case n of
Open -> writeTVar statusVar $! n + 1
Closing -> throwSTM ScopeClosing
Closed -> throwSTM (ErrorCall "ki: scope closed")
Open -> nonblockingWriteTVar' statusVar (n + 1)
Closing -> nonblockingThrowSTM ScopeClosing
Closed -> nonblockingThrowSTM (ErrorCall "ki: scope closed")
childId <- IntSupply.next nextChildIdSupply
@ -245,11 +246,9 @@ spawn
childThreadId <- myThreadId
labelThread childThreadId label
case allocationLimit of
Nothing -> pure ()
Just bytes -> do
setAllocationCounter (byteCountToInt64 bytes)
enableAllocationLimit
for_ allocationLimit \bytes -> do
setAllocationCounter (byteCountToInt64 bytes)
enableAllocationLimit
let -- Action that sets the masking state from the current (MaskedInterruptible) to the requested one.
atRequestedMaskingState :: IO a -> IO a
@ -261,12 +260,12 @@ spawn
runUnexceptionalIO (action childId atRequestedMaskingState)
atomically (unrecordChild childrenVar childId)
nonblockingAtomically (unrecordChild childrenVar childId)
-- Record the child as having started. Not allowed to retry.
atomically do
n <- readTVar statusVar
writeTVar statusVar $! n - 1
nonblockingAtomically do
n <- nonblockingReadTVar statusVar
nonblockingWriteTVar' statusVar (n - 1)
recordChild childrenVar childId childThreadId
pure childThreadId
@ -275,23 +274,19 @@ spawn
--
-- * Flipping `Nothing` to `Just childThreadId` (common case: we record child before it unrecords itself)
-- * Flipping `Just _` to `Nothing` (uncommon case: we observe that a child already unrecorded itself)
--
-- Never retries.
recordChild :: TVar (IntMap ThreadId) -> Tid -> ThreadId -> STM ()
recordChild :: TVar (IntMap ThreadId) -> Tid -> ThreadId -> NonblockingSTM ()
recordChild childrenVar childId childThreadId = do
children <- readTVar childrenVar
writeTVar childrenVar $! IntMap.Lazy.alter (maybe (Just childThreadId) (const Nothing)) childId children
children <- nonblockingReadTVar childrenVar
nonblockingWriteTVar' childrenVar (IntMap.Lazy.alter (maybe (Just childThreadId) (const Nothing)) childId children)
-- Unrecord a child (ourselves) by either:
--
-- * Flipping `Just childThreadId` to `Nothing` (common case: parent recorded us first)
-- * Flipping `Nothing` to `Just undefined` (uncommon case: we terminate and unrecord before parent can record us).
--
-- Never retries.
unrecordChild :: TVar (IntMap ThreadId) -> Tid -> STM ()
unrecordChild :: TVar (IntMap ThreadId) -> Tid -> NonblockingSTM ()
unrecordChild childrenVar childId = do
children <- readTVar childrenVar
writeTVar childrenVar $! IntMap.Lazy.alter (maybe (Just undefined) (const Nothing)) childId children
children <- nonblockingReadTVar childrenVar
nonblockingWriteTVar' childrenVar (IntMap.Lazy.alter (maybe (Just undefined) (const Nothing)) childId children)
-- | Wait until all threads created within a scope terminate.
awaitAll :: Scope -> STM ()