dejafu/dejafu-tests/Examples/SearchParty.hs

262 lines
8.6 KiB
Haskell
Raw Normal View History

2016-05-26 15:54:13 +03:00
{-# LANGUAGE Rank2Types #-}
{-
The Search Party library:
https://github.com/barrucadu/search-party
This was originally intended as a first nontrivial test for dejafu,
and while writing it I found a bug in the initial implementation of
result lists, which was later replaced with a @Stream@ type. This is
a trimmed-down version with most of the functions not essential to
exhibiting the bug removed.
-}
-- | Concurrent nondeterministic search.
module Examples.SearchParty where
2016-05-26 15:54:13 +03:00
import Control.Concurrent.Classy.STM.TMVar (TMVar, newEmptyTMVar, readTMVar, isEmptyTMVar, putTMVar, tryPutTMVar, tryTakeTMVar)
import Control.Monad (unless, when)
import Control.Monad.Conc.Class
import Control.Monad.STM.Class
import Data.Functor (void)
import Data.Maybe (fromJust, isNothing)
-- test imports
import Data.List (permutations)
import Test.DejaFu (Predicate, Result(..), alwaysTrue2)
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (hUnitTestToTests)
import Test.HUnit (test)
import Test.HUnit.DejaFu (testDejafu)
2016-05-26 15:54:13 +03:00
import Examples.SearchParty.Impredicative
-- | The dejafu test suite for this module. The concurrent filter is
-- expected to give results that differ between executions, so the
-- result-list check is run through 'invPred' (inverted).
tests :: [Test]
tests = hUnitTestToTests (test [filterTest])
  where
    filterTest = testDejafu concFilter "concurrent filter" (invPred checkResultLists)
-- | Filter a list concurrently.
concFilter :: MonadConc m => m [Int]
concFilter = unsafeRunFind $ [0..5] @! const True
-- | Invert a predicate: the resulting predicate passes exactly when
-- the original one fails. All other fields of the 'Result' are kept.
invPred :: Predicate a -> Predicate a
invPred p = flipPass . p
  where
    flipPass r = r { _pass = not (_pass r) }
-- | Check that two lists of results are equal, modulo order.
--
-- Results are compared with dejafu's 'alwaysTrue2'. Two successful
-- results must contain the same elements with the same
-- multiplicities, in any order. Any other pair of outcomes (e.g.
-- failures) must be equal outright.
checkResultLists :: Eq a => Predicate [a]
checkResultLists = alwaysTrue2 checkLists
  where
    checkLists (Right xs) (Right ys) = sameElements xs ys
    checkLists a b = a == b

    -- Multiset equality using only 'Eq'. Each element of the first
    -- list is removed from the second. The lists match iff every
    -- removal succeeds and nothing is left over. This is O(n^2),
    -- whereas searching 'permutations' is O(n! * n).
    sameElements [] ys = null ys
    sameElements (x:xs) ys = case break (== x) ys of
      (before, _ : after) -> sameElements xs (before ++ after)
      (_, []) -> False
-------------------------------------------------------------------------------
-- | A value of type @Find m a@ represents a concurrent search
-- computation (happening in the 'MonadConc' monad @m@) which may
-- produce a value of type @a@, or fail. If a value can be returned,
-- one will be (although it's nondeterministic which one will actually
-- be returned). Usually you will be working with values of type @Find
-- IO a@, but the generality allows for testing.
--
-- You should prefer using the 'Applicative' instance over the 'Monad'
-- instance if you can, as the 'Applicative' preserves parallelism.
--
-- 'unFind' gives the underlying action. Running it produces the
-- search's 'WorkItem', from which the outcome is then read (see
-- 'result' and 'unsafeRunFind').
newtype Find m a = Find { unFind :: m (WorkItem m a) }
-------------------------------------------------------------------------------
-- Instances
-- | 'fmap' defers applying the function until the value is demanded,
-- so mapping over a search never blocks on its result.
instance MonadConc m => Functor (Find m) where
  fmap g = Find . fmap (fmap g) . unFind
-- | '<*>' runs both searches in parallel and fails as soon as either
-- of them fails, so short-circuiting is symmetric in the two
-- arguments.
instance MonadConc m => Applicative (Find m) where
  pure = Find . workItem' . Just

  Find mf <*> Find ma = Find $ do
    fItem <- mf
    aItem <- ma
    -- Wait until both succeed, or one fails (which kills the rest).
    bothOk <- blockOn [void fItem, void aItem]
    if not bothOk
      then workItem' Nothing
      else do
        g <- unsafeResult fItem
        x <- unsafeResult aItem
        workItem' (Just (g x))
-- | '>>=' should be avoided, as it necessarily imposes sequencing,
-- and blocks until the value being bound has been computed.
--
-- If the bound search fails, the combined search fails too. The
-- failure is represented as a failed work item (@workItem' Nothing@),
-- exactly as 'fail' represents it. It is not done by calling 'fail'
-- in the underlying monad @m@, whose behaviour depends on @m@ (e.g.
-- it throws in 'IO') rather than yielding a failed search.
instance MonadConc m => Monad (Find m) where
  return = pure

  fail _ = Find $ workItem' Nothing

  (Find mf) >>= g = Find $ do
    f <- mf
    res <- result f
    case res of
      Just a -> unFind $ g a
      Nothing -> workItem' Nothing
--------------------------------------------------------------------------------
-- Execution
-- | Unsafe version of 'runFind': run a search and wait for its value.
-- If the search fails, the returned value is an error at runtime.
unsafeRunFind :: MonadConc m => Find m a -> m a
unsafeRunFind search = unsafeResult =<< unFind search
--------------------------------------------------------------------------------
-- Basic Searches
-- | Return all elements of a list satisfying a predicate. The order
-- of the results may differ between executions.
(@!) :: MonadConc m => [a] -> (a -> Bool) -> Find m [a]
xs @! p = allOf (map candidate xs)
  where
    -- Each element becomes its own search, which fails if it is
    -- rejected by the predicate.
    candidate x
      | p x = success x
      | otherwise = failure
-- | A search which immediately succeeds with the given value.
success :: MonadConc m => a -> Find m a
success x = pure x
-- | A search which always fails: its work item holds 'Nothing'. This
-- is the same value the 'Monad' instance's 'fail' produces.
failure :: MonadConc m => Find m a
failure = Find (workItem' Nothing)
-- | Run all the given searches and return the results of those which
-- did not fail. The order of the results is nondeterministic.
allOf :: MonadConc m => [Find m a] -> Find m [a]
allOf [] = success []
allOf searches = Find (toItem <$> work False (map unFind searches))
  where
    -- Wrap the shared result variable and kill action as a work item.
    toItem (var, kill) = workItem var id kill
-------------------------------------------------------------------------------
-- Combinators
-- INTERNAL --
-------------------------------------------------------------------------------
-- Types
2016-05-26 15:54:13 +03:00
-- See Examples/SearchParty/Impredicative.hs (module Examples.SearchParty.Impredicative, imported above).
-------------------------------------------------------------------------------
-- Processing work items
-- | Block until all the given work items have succeeded (returning
-- 'True'), or until any of them has failed (returning 'False'). On
-- failure, the kill action of every item in the list is run.
blockOn :: MonadConc m => [WorkItem m ()] -> m Bool
blockOn items = do
  -- 'retry' re-runs the transaction until one item has failed or
  -- every item has succeeded.
  allSucceeded <- atomically $ do
    states <- mapM getState items
    if HasFailed `elem` states
      then return False
      else if all (== HasSucceeded) states
        then return True
        else retry
  unless allSucceeded $ mapM_ (_killme . unWrap) items
  return allSucceeded
-- | Wait for a work item to finish and return its (mapped) outcome:
-- 'Nothing' if it failed. This blocks until the outcome is available,
-- so use it with care to avoid losing parallelism.
result :: MonadConc m => WorkItem m a -> m (Maybe a)
result f = do
  outcome <- atomically (readTMVar (_result (unWrap f)))
  return (_mapped (unWrap f) <$> outcome)
-- | Like 'result', but assumes the computation succeeded. If it
-- failed, the returned value errors when forced (via 'fromJust').
unsafeResult :: MonadConc m => WorkItem m a -> m a
unsafeResult f = fromJust <$> result f
-- | Classify a work item as still computing, failed, or succeeded,
-- based on whether its result variable is filled and with what.
getState :: MonadConc m => WorkItem m a -> STM m WorkState
getState f = do
  pending <- isEmptyTMVar (_result (unWrap f))
  if pending
    then return StillComputing
    else classify <$> hasFailed f
  where
    classify failed
      | failed = HasFailed
      | otherwise = HasSucceeded
-- | Check whether a work item has finished with a failure ('Nothing').
-- A still-running item has not failed (yet), so this returns 'False'
-- for it.
hasFailed :: MonadConc m => WorkItem m a -> STM m Bool
hasFailed f = do
  stillRunning <- isEmptyTMVar (_result (unWrap f))
  case stillRunning of
    True -> return False
    False -> isNothing <$> readTMVar (_result (unWrap f))
-------------------------------------------------------------------------------
-- Work stealing
-- | Push a batch of work to the queue, returning a 'TMVar' that can
-- be blocked on to get the result, and an action that can be used to
-- kill the computation. If the first argument is true, as soon as one
-- succeeds, the others are killed; otherwise all results are
-- gathered.
--
-- NOTE(review): the result-gathering below appears to be the buggy
-- "initial implementation of result lists" that this module exists
-- to exhibit. @tests@ checks it with an inverted predicate, so do not
-- "fix" it here. In particular:
--
-- * each worker thread stops after its first successful item (there
--   is no recursive call in the success branch), and
-- * the result 'TMVar' becomes readable as soon as the first success
--   is stored, so a reader may observe a partial list.
work :: MonadConc m => Bool -> [m (WorkItem m a)] -> m (TMVar (STM m) (Maybe [a]), m ())
work shortcircuit workitems = do
  res <- atomically newEmptyTMVar
  kill <- atomically newEmptyTMVar
  caps <- getNumCapabilities
  dtid <- fork $ driver caps res kill
  -- The driver thread publishes its kill action through @kill@;
  -- block until it has done so.
  killme <- atomically $ readTMVar kill
  return (res, killme >> killThread dtid)

  where
    -- If there's only one capability don't bother with threads.
    -- (This case does not look at @shortcircuit@.)
    driver 1 res kill = do
      atomically . putTMVar kill $ failit res
      remaining <- newCRef workitems
      process remaining res

    -- Fork off as many threads as there are capabilities, and queue
    -- up the remaining work.
    driver caps res kill = do
      -- Shared queue of not-yet-started work items.
      remaining <- newCRef workitems
      tids <- mapM (\cap -> forkOn cap $ process remaining res) [0..caps-1]
      -- Construct an action to short-circuit the computation.
      atomically . putTMVar kill $ failit res >> mapM_ killThread tids
      -- If short-circuiting, block until there is a result then kill
      -- any still-running threads.
      when shortcircuit $ do
        _ <- atomically $ readTMVar res
        mapM_ killThread tids

    -- Process a work item and store the result if it is a success,
    -- otherwise continue. After storing a success this thread stops.
    -- When the queue is empty, record failure (which only takes
    -- effect if nothing has been stored yet).
    process remaining res = do
      -- Atomically pop the head of the queue, if any.
      mitem <- atomicModifyCRef remaining $ \rs -> if null rs then ([], Nothing) else (tail rs, Just $ head rs)
      case mitem of
        Just item -> do
          fwrap <- item
          maybea <- result fwrap
          case maybea of
            -- Prepend to any results already stored, or start a new
            -- list (also overwriting a stored failure).
            Just a -> atomically $ do
              val <- tryTakeTMVar res
              case val of
                Just (Just as) -> putTMVar res $ Just (a:as)
                _ -> putTMVar res $ Just [a]
            Nothing -> process remaining res
        Nothing -> failit res

    -- Record that a computation failed. 'tryPutTMVar' only writes if
    -- the variable is empty, so an existing result is never replaced.
    failit res = atomically . void $ tryPutTMVar res Nothing