dejafu/dejafu-tests/Examples/SearchParty.hs

296 lines
9.9 KiB
Haskell
Raw Normal View History

{-# LANGUAGE ImpredicativeTypes #-}
{-# LANGUAGE Rank2Types #-}
{-
The Search Party library:
https://github.com/barrucadu/search-party
It was originally intended as a first nontrivial test for dejafu, and
in testing it I found a bug in the initial implementation of result
lists, which was later replaced with a @Stream@ type. This is a
trimmed-down version with most of the functions not essential to
exhibiting the bug removed.
-}
-- | Concurrent nondeterministic search.
module Examples.SearchParty where
import Control.Concurrent.Classy.STM.TMVar (TMVar, newEmptyTMVar, newTMVar, readTMVar, isEmptyTMVar, putTMVar, tryPutTMVar, tryTakeTMVar)
import Control.Monad (unless, when)
import Control.Monad.Conc.Class
import Control.Monad.STM.Class
import Data.Functor (void)
import Data.Maybe (fromJust, isNothing)
import Unsafe.Coerce (unsafeCoerce)
-- test imports
import Data.List (permutations)
import Test.DejaFu (Predicate, Result(..), alwaysTrue2)
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (hUnitTestToTests)
import Test.HUnit (test)
import Test.HUnit.DejaFu (testDejafu)
-- | The HUnit test cases for this module, converted for test-framework.
--
-- 'concFilter' is checked with the inverted 'checkResultLists'
-- predicate, so the test passes exactly when 'checkResultLists'
-- itself does not.
tests :: [Test]
tests = hUnitTestToTests (test [filterTest])
  where
    filterTest =
      testDejafu concFilter "concurrent filter" (invPred checkResultLists)
-- | Concurrently filter @[0..5]@ with a predicate that keeps every
-- element, returning whatever list the search produces.
concFilter :: MonadConc m => m [Int]
concFilter = unsafeRunFind (candidates @! keepAll)
  where
    candidates = [0 .. 5]
    keepAll _ = True
-- | Invert a predicate: the 'Result' it produces is returned
-- unchanged except that its '_pass' flag is negated.
invPred :: Predicate a -> Predicate a
invPred p xs = flipPass (p xs)
  where
    flipPass r = r { _pass = not (_pass r) }
-- | Check that two lists of results are equal, modulo order.
--
-- Two successful results pass when one list is a rearrangement of the
-- other: the same elements with the same multiplicities, compared
-- with '=='. Any other pair of results is compared with '=='.
checkResultLists :: Eq a => Predicate [a]
checkResultLists = alwaysTrue2 checkLists where
  checkLists (Right xs) (Right ys) = xs `isPermutationOf` ys
  checkLists a b = a == b

  -- Multiset equality: for each element of the first list, remove one
  -- matching element from the second. This is equivalent to
  -- @xs `elem` permutations ys@, but quadratic rather than factorial
  -- in the list length.
  isPermutationOf [] ys = null ys
  isPermutationOf (x:xs) ys = case break (x ==) ys of
    (before, _ : after) -> isPermutationOf xs (before ++ after)
    (_, []) -> False
-------------------------------------------------------------------------------
-- | A value of type @Find m a@ represents a concurrent search
-- computation (happening in the 'MonadConc' monad @m@) which may
-- produce a value of type @a@, or fail. If a value can be returned,
-- one will be, although which one is nondeterministic. Usually you
-- will be working with values of type @Find IO a@, but the generality
-- allows for testing.
--
-- Prefer the 'Applicative' instance over the 'Monad' instance where
-- possible, because the 'Applicative' instance preserves parallelism.
newtype Find m a = Find
  { unFind :: m (WorkItem m a)
    -- ^ The action that sets up the search and returns a 'WorkItem'
    -- for its eventual result.
  }
-------------------------------------------------------------------------------
-- Instances
-- | Mapping over a search does not wait for it. The function is
-- attached to the underlying 'WorkItem' and is only applied when that
-- item's result is read.
instance MonadConc m => Functor (Find m) where
  fmap g (Find mf) = Find (fmap onItem mf)
    where
      onItem item = fmap g item
-- | 'pure' is a search that has already succeeded.
--
-- '<*>' starts both searches, then waits until both succeed or either
-- one fails. A failure on either side fails the whole computation
-- straight away, so short-circuiting is symmetric.
instance MonadConc m => Applicative (Find m) where
  pure x = Find (workItem' (Just x))

  (Find mkFun) <*> (Find mkArg) = Find $ do
    funItem <- mkFun
    argItem <- mkArg
    bothSucceeded <- blockOn [void funItem, void argItem]
    if not bothSucceeded
      then workItem' Nothing
      else do
        g <- unsafeResult funItem
        x <- unsafeResult argItem
        workItem' (Just (g x))
-- | '>>=' should be avoided, as it necessarily imposes sequencing,
-- and blocks until the value being bound has been computed.
--
-- If the bound search fails, the whole computation fails in the same
-- way as 'fail' and '<*>' do, by producing an already-failed work item.
instance MonadConc m => Monad (Find m) where
  return = pure

  fail _ = Find $ workItem' Nothing

  (Find mf) >>= g = Find $ do
    f <- mf
    res <- result f
    case res of
      Just a -> unFind $ g a
      -- Propagate failure as a failed work item. A bare @fail@ here
      -- would be the underlying monad @m@'s 'fail' (for IO, a thrown
      -- exception), not the failing search described above.
      Nothing -> workItem' Nothing
--------------------------------------------------------------------------------
-- Execution
-- | Run a search and return its result, assuming it succeeds. This
-- is the unsafe counterpart of 'runFind'. If the search fails, the
-- returned value is an error when it is forced.
unsafeRunFind :: MonadConc m => Find m a -> m a
unsafeRunFind (Find mf) = do
  item <- mf
  unsafeResult item
--------------------------------------------------------------------------------
-- Basic Searches
-- | Keep every element of a list that satisfies a predicate, searching
-- concurrently. The order of the results may differ between
-- executions.
(@!) :: MonadConc m => [a] -> (a -> Bool) -> Find m [a]
xs @! p = allOf (map candidate xs)
  where
    candidate x
      | p x = success x
      | otherwise = failure
-- | A search that immediately succeeds with the given value.
success :: MonadConc m => a -> Find m a
success x = pure x
-- | A search that never produces a value.
failure :: MonadConc m => Find m a
failure = fail noReason
  where
    noReason = ""
-- | Run every search, gathering the values of the ones that succeed.
-- The order of the gathered values is nondeterministic. An empty list
-- of searches succeeds with an empty list.
allOf :: MonadConc m => [Find m a] -> Find m [a]
allOf searches
  | null searches = success []
  | otherwise = Find $ do
      (resultVar, killer) <- work False (map unFind searches)
      return (workItem resultVar id killer)
-------------------------------------------------------------------------------
-- Combinators
-- INTERNAL --
-------------------------------------------------------------------------------
-- Types
-- | A unit of work in a monad @m@ which will produce a final result
-- of type @a@.
--
-- The intermediate result type @x@ is quantified inside the field, so
-- it does not appear in the type of a 'WorkItem'. Values are built by
-- 'workItem', which coerces a concrete 'WorkItem'' into this form.
newtype WorkItem m a = WorkItem
  { unWrap :: forall x. WorkItem' m x a
  }
-- | Mapping composes the function onto the item's post-processing.
-- The result variable and kill action are reused unchanged.
instance Functor (WorkItem m) where
  fmap f (WorkItem w) = workItem var (\v -> f (_mapped w v)) kill
    where
      var = _result w
      kill = _killme w
-- | A unit of work in a monad @m@ producing a result of type @x@,
-- which will then be transformed into a value of type @a@.
data WorkItem' m x a = WorkItem'
  { _result :: TMVar (STM m) (Maybe x)
    -- ^ The future result of the computation. It is empty while the
    -- work is running, then holds @Just@ a value on success or
    -- 'Nothing' on failure (this is how 'getState' and 'hasFailed'
    -- read it).
  , _mapped :: x -> a
    -- ^ Post-processing applied to the raw result when it is read.
  , _killme :: m ()
    -- ^ Fail the computation, if it's still running.
  }
-- | The possible states that a work item may be in.
data WorkState
  = StillComputing
    -- ^ No result has been stored yet.
  | HasFailed
    -- ^ The stored result is 'Nothing'.
  | HasSucceeded
    -- ^ The stored result is a value.
  deriving Eq
-- | Construct a 'WorkItem' from its result variable, its
-- post-processing function, and its kill action.
--
-- The concrete intermediate type @x@ is erased by coercing the
-- 'WorkItem'' into the field type @forall x. WorkItem' m x a@.
-- NOTE(review): this relies on consumers only ever using '_result'
-- and '_mapped' from the same item together, so the erased @x@ lines
-- up at runtime. Nothing in the types enforces this.
workItem :: TMVar (STM m) (Maybe x) -> (x -> a) -> m () -> WorkItem m a
workItem res mapp kill = wrap $ WorkItem' res mapp kill where
  -- Really not nice, but I have had difficulty getting GHC to unify
  -- @WorkItem' m x a@ with @forall x. WorkItem' m x a@
  --
  -- This needs ImpredicativeTypes in GHC 7.8.
  wrap :: WorkItem' m x a -> WorkItem m a
  wrap = WorkItem . unsafeCoerce
-- | Build a 'WorkItem' that has already finished. Its result variable
-- starts full with the given value, its post-processing is 'id', and
-- its kill action does nothing.
workItem' :: MonadConc m => Maybe a -> m (WorkItem m a)
workItem' a = fmap wrapVar (atomically (newTMVar a))
  where
    wrapVar var = workItem var id (pure ())
-------------------------------------------------------------------------------
-- Processing work items
-- | Wait until every given work item has succeeded, which returns
-- 'True', or until any one of them has failed, which returns 'False'.
-- On failure, the kill action of every item is run so that the
-- still-running ones stop. An empty list succeeds immediately.
blockOn :: MonadConc m => [WorkItem m ()] -> m Bool
blockOn fs = do
  ok <- atomically (mapM getState fs >>= decide)
  unless ok $ mapM_ (_killme . unWrap) fs
  return ok
  where
    -- Retry the transaction until the outcome is settled.
    decide states
      | HasFailed `elem` states = return False
      | all (== HasSucceeded) states = return True
      | otherwise = retry
-- | Wait for a work item to finish and return its post-processed
-- value, or 'Nothing' if it failed. This blocks until the result
-- variable is filled, so calling it early gives up parallelism.
result :: MonadConc m => WorkItem m a -> m (Maybe a)
result f = fmap (fmap post) readResult
  where
    post = _mapped (unWrap f)
    readResult = atomically (readTMVar (_result (unWrap f)))
-- | Unsafe version of 'result', this will error at runtime if the
-- computation fails.
--
-- The error is raised lazily, when the returned value is forced. Its
-- message names this function, unlike the generic 'fromJust' failure,
-- so the failed search can be traced.
unsafeResult :: MonadConc m => WorkItem m a -> m a
unsafeResult = fmap (maybe (error "Examples.SearchParty.unsafeResult: computation failed") id) . result
-- | Classify a work item without blocking. An empty result variable
-- means it is still computing; otherwise 'hasFailed' decides between
-- failure and success.
getState :: MonadConc m => WorkItem m a -> STM m WorkState
getState f = do
  running <- isEmptyTMVar (_result (unWrap f))
  if running
    then return StillComputing
    else classify <$> hasFailed f
  where
    classify True = HasFailed
    classify False = HasSucceeded
-- | Whether a work item has finished with a 'Nothing' result. While
-- the item is still running (its result variable is empty), this
-- returns 'False' without blocking.
hasFailed :: MonadConc m => WorkItem m a -> STM m Bool
hasFailed f = isEmptyTMVar var >>= check
  where
    var = _result (unWrap f)
    check True = return False
    check False = fmap isNothing (readTMVar var)
-------------------------------------------------------------------------------
-- Work stealing
-- | Push a batch of work to the queue, returning a 'TMVar' that can
-- be blocked on to get the result, and an action that can be used to
-- kill the computation. If the first argument is true, as soon as one
-- succeeds, the others are killed; otherwise all results are
-- gathered.
--
-- A driver thread is forked to run the work. 'work' itself blocks only
-- until the driver has published its kill action, so the result 'TMVar'
-- is returned while the work may still be running. The returned kill
-- action runs the driver's kill action and then kills the driver
-- thread.
--
-- NOTE(review): as written, each worker in @process@ stops after
-- storing its first success. A worker that finds the queue empty
-- stores 'Nothing' if no result has been stored yet. Readers of the
-- result 'TMVar' can therefore see a partial list, or 'Nothing',
-- rather than every result. The module header says this file exists
-- to exhibit a result-list bug, and the test inverts its predicate, so
-- this is presumably intentional here. Confirm before changing it.
work :: MonadConc m => Bool -> [m (WorkItem m a)] -> m (TMVar (STM m) (Maybe [a]), m ())
work shortcircuit workitems = do
  res <- atomically newEmptyTMVar
  kill <- atomically newEmptyTMVar
  caps <- getNumCapabilities
  dtid <- fork $ driver caps res kill
  -- Wait for the driver to publish its kill action.
  killme <- atomically $ readTMVar kill
  return (res, killme >> killThread dtid)
  where
    -- If there's only one capability don't bother with threads: the
    -- driver thread processes the queue itself. The short-circuit flag
    -- is not consulted on this path.
    driver 1 res kill = do
      atomically . putTMVar kill $ failit res
      remaining <- newCRef workitems
      process remaining res
    -- Fork off as many threads as there are capabilities, all sharing
    -- the same queue of remaining work.
    driver caps res kill = do
      remaining <- newCRef workitems
      tids <- mapM (\cap -> forkOn cap $ process remaining res) [0..caps-1]
      -- Construct an action to short-circuit the computation: record
      -- failure (if nothing is stored yet), then kill the workers.
      atomically . putTMVar kill $ failit res >> mapM_ killThread tids
      -- If short-circuiting, block until there is a result then kill
      -- any still-running threads.
      when shortcircuit $ do
        _ <- atomically $ readTMVar res
        mapM_ killThread tids
    -- Take the next work item from the shared queue and run it. On
    -- success, add the value to the stored list and stop; this worker
    -- takes no further items. On failure, move on to the next item.
    -- When the queue is empty, record failure if nothing is stored yet.
    process remaining res = do
      -- Atomically pop the head of the queue, or get Nothing if empty.
      mitem <- atomicModifyCRef remaining $ \rs -> if null rs then ([], Nothing) else (tail rs, Just $ head rs)
      case mitem of
        Just item -> do
          fwrap <- item
          maybea <- result fwrap
          case maybea of
            -- Prepend to an already-stored list. An empty variable or a
            -- stored Nothing is replaced with a singleton list.
            Just a -> atomically $ do
              val <- tryTakeTMVar res
              case val of
                Just (Just as) -> putTMVar res $ Just (a:as)
                _ -> putTMVar res $ Just [a]
            Nothing -> process remaining res
        Nothing -> failit res
    -- Record that a computation failed. tryPutTMVar only stores into
    -- an empty variable, so an existing result is left untouched.
    failit res = atomically . void $ tryPutTMVar res Nothing