{-# LANGUAGE Rank2Types #-}

{-
The Search Party library: https://github.com/barrucadu/search-party

Originally intended as a first nontrivial test for dejafu, I found a
bug with the initial implementation of result lists, which was later
replaced with a @Stream@ type.

This is a trimmed down version with most of the functions not
essential to exhibiting the bug removed.
-}

-- | Concurrent nondeterministic search.
module Examples.SearchParty where

import Control.Concurrent.Classy.STM.TMVar (TMVar, newEmptyTMVar, readTMVar, isEmptyTMVar, putTMVar, tryPutTMVar, tryTakeTMVar)
import Control.Monad (unless, when)
import Control.Monad.Conc.Class
import Control.Monad.STM.Class
import Data.Functor (void)
import Data.Maybe (fromJust, isNothing)

-- test imports
import Data.List (permutations)
import Test.DejaFu (Predicate, Result(..), alwaysTrue2)
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (hUnitTestToTests)
import Test.HUnit (test)
import Test.HUnit.DejaFu (testDejafu)

import Examples.SearchParty.Impredicative

-- | The test cases of this module. The single case runs 'concFilter'
-- under dejafu and expects 'checkResultLists' to be violated (the
-- predicate is wrapped in 'invPred'), i.e. it demonstrates the bug
-- mentioned in the module header.
tests :: [Test]
tests = hUnitTestToTests (test cases)
  where
    cases =
      [ testDejafu concFilter "concurrent filter" (invPred checkResultLists)
      ]

-- | Concurrently filter @[0..5]@ with a predicate that accepts every
-- element.
concFilter :: MonadConc m => m [Int]
concFilter = unsafeRunFind ([0..5] @! const True)

-- | Negate a predicate: the '_pass' flag of the result is flipped,
-- every other field is left as the original predicate produced it.
invPred :: Predicate a -> Predicate a
invPred p xs = outcome { _pass = not (_pass outcome) }
  where
    outcome = p xs

-- | Predicate built with 'alwaysTrue2': two successful results must
-- contain the same elements in some order; any other pair of
-- outcomes must be equal.
checkResultLists :: Eq a => Predicate [a]
checkResultLists = alwaysTrue2 sameModuloOrder
  where
    -- Only 'Eq' is available, so order-insensitivity is checked by
    -- searching the permutations of the second list.
    sameModuloOrder (Right xs) (Right ys) = elem xs (permutations ys)
    sameModuloOrder l r = l == r

-------------------------------------------------------------------------------

-- | A value of type @Find m a@ represents a concurrent search
-- computation (happening in the 'MonadConc' monad @m@) which may
-- produce a value of type @a@, or fail.
-- If a value can be returned, one will be (although it's
-- nondeterministic which one will actually be returned). Usually you
-- will be working with values of type @Find IO a@, but the generality
-- allows for testing.
--
-- You should prefer using the 'Applicative' instance over the 'Monad'
-- instance if you can, as the 'Applicative' preserves parallelism.
newtype Find m a = Find { unFind :: m (WorkItem m a) }

-------------------------------------------------------------------------------
-- Instances

-- | 'fmap' delays applying the function until the value is demanded,
-- to avoid blocking.
instance MonadConc m => Functor (Find m) where
  fmap g (Find mf) = Find $ fmap g <$> mf

-- | '<*>' performs both computations in parallel, and immediately
-- fails as soon as one does, giving a symmetric short-circuiting
-- behaviour.
instance MonadConc m => Applicative (Find m) where
  pure a = Find . workItem' $ Just a

  (Find mf) <*> (Find ma) = Find $ do
    f <- mf
    a <- ma

    -- Wait for both sides; 'blockOn' returns 'False' (and kills the
    -- other side) as soon as either fails.
    bothOk <- blockOn [void f, void a]

    if bothOk
      then do
        -- Both work items have succeeded, so 'unsafeResult' cannot
        -- hit a 'Nothing' here.
        fres <- unsafeResult f
        ares <- unsafeResult a
        workItem' . Just $ fres ares
      else workItem' Nothing

-- | '>>=' should be avoided, as it necessarily imposes sequencing,
-- and blocks until the value being bound has been computed.
instance MonadConc m => Monad (Find m) where
  return = pure

  fail _ = Find $ workItem' Nothing

  (Find mf) >>= g = Find $ do
    f   <- mf
    res <- result f

    case res of
      Just a  -> unFind $ g a
      -- This do block runs in the underlying monad @m@, not in
      -- 'Find', so calling 'fail' here would invoke @m@'s 'fail'
      -- instead of producing a failed search. Build the failed work
      -- item directly, exactly as 'fail' for 'Find' does.
      Nothing -> workItem' Nothing

--------------------------------------------------------------------------------
-- Execution

-- | Unsafe version of 'runFind'. This will error at runtime if the
-- computation fails.
unsafeRunFind :: MonadConc m => Find m a -> m a
unsafeRunFind (Find mf) = mf >>= unsafeResult

--------------------------------------------------------------------------------
-- Basic Searches

-- | Return all elements of a list satisfying a predicate, the order
-- may not be consistent between executions.
(@!) :: MonadConc m => [a] -> (a -> Bool) -> Find m [a]
as @! f = allOf [if f a then success a else failure | a <- as]

-- | Search which always succeeds.
success :: MonadConc m => a -> Find m a
success = pure

-- | Search which always fails.
failure :: MonadConc m => Find m a
failure = fail ""

-- | Return all non-failing results, the order is nondeterministic.
-- An empty list of searches succeeds immediately with @[]@ without
-- starting any work.
allOf :: MonadConc m => [Find m a] -> Find m [a]
allOf [] = success []
allOf as = Find $ do
  -- 'False': do not short-circuit, gather every result.
  (var, kill) <- work False $ map unFind as
  return $ workItem var id kill

-------------------------------------------------------------------------------
-- Combinators

-- INTERNAL --

-------------------------------------------------------------------------------
-- Types

-- See SearchPartyImpred.hs
-- (presumably now the imported "Examples.SearchParty.Impredicative"
-- module -- TODO confirm)

-------------------------------------------------------------------------------
-- Processing work items

-- | Block until all of the given computations have successfully
-- completed. If any fail, this immediately returns 'False' and kills
-- the still-running ones.
blockOn :: MonadConc m => [WorkItem m ()] -> m Bool
blockOn fs = do
  -- Block until one thing fails, or everything succeeds. The states
  -- are all read in one transaction, which retries while neither
  -- condition holds.
  allOk <- atomically $ do
    states <- mapM getState fs
    case (HasFailed `elem` states, all (==HasSucceeded) states) of
      (True, _) -> return False
      (_, True) -> return True
      _ -> retry

  -- Kill everything if something failed.
  unless allOk $ mapM_ (_killme . unWrap) fs

  return allOk

-- | Get the result of a computation, this blocks until the result is
-- present, so be careful not to lose parallelism. The work item's
-- mapping function is applied to a successful result.
result :: MonadConc m => WorkItem m a -> m (Maybe a)
result f = fmap (_mapped $ unWrap f) <$> res where
  res = atomically . readTMVar . _result $ unWrap f

-- | Unsafe version of 'result', this will error at runtime if the
-- computation fails.
unsafeResult :: MonadConc m => WorkItem m a -> m a
unsafeResult = fmap fromJust . result

-- | Get the current state of a work item.
getState :: MonadConc m => WorkItem m a -> STM m WorkState
getState f = do
  -- An empty result variable means the computation has not finished.
  empty <- isEmptyTMVar . _result $ unWrap f
  if empty
    then return StillComputing
    else do
      failed <- hasFailed f
      return $ if failed then HasFailed else HasSucceeded

-- | Check if a work item has failed. If the computation has not
-- terminated, this immediately returns 'False'.
hasFailed :: MonadConc m => WorkItem m a -> STM m Bool
hasFailed f = do
  working <- isEmptyTMVar . _result $ unWrap f
  if working
    then return False
    -- A stored 'Nothing' is how a failure is recorded.
    else isNothing <$> readTMVar (_result $ unWrap f)

-------------------------------------------------------------------------------
-- Work stealing

-- | Push a batch of work to the queue, returning a 'TMVar' that can
-- be blocked on to get the result, and an action that can be used to
-- kill the computation. If the first argument is true, as soon as one
-- succeeds, the others are killed; otherwise all results are
-- gathered.
--
-- Note that the single-capability driver does not consult the
-- short-circuit flag at all.
work :: MonadConc m => Bool -> [m (WorkItem m a)] -> m (TMVar (STM m) (Maybe [a]), m ())
work shortcircuit workitems = do
  res    <- atomically newEmptyTMVar
  kill   <- atomically newEmptyTMVar
  caps   <- getNumCapabilities
  dtid   <- fork $ driver caps res kill
  -- Blocks until the driver has published its kill action.
  killme <- atomically $ readTMVar kill

  return (res, killme >> killThread dtid)

  where
    -- If there's only one capability don't bother with threads.
    driver 1 res kill = do
      atomically . putTMVar kill $ failit res
      remaining <- newCRef workitems
      process remaining res

    -- Fork off as many threads as there are capabilities, and queue
    -- up the remaining work.
    driver caps res kill = do
      remaining <- newCRef workitems
      tids <- mapM (\cap -> forkOn cap $ process remaining res) [0..caps-1]

      -- Construct an action to short-circuit the computation.
      atomically . putTMVar kill $ failit res >> mapM_ killThread tids

      -- If short-circuiting, block until there is a result then kill
      -- any still-running threads.
      when shortcircuit $ do
        _ <- atomically $ readTMVar res
        mapM_ killThread tids

    -- Process a work item and store the result if it is a success,
    -- otherwise continue.
    --
    -- NOTE(review): after storing a success this does not loop back
    -- for more work, so which results end up in the list depends on
    -- scheduling. This looks like the result-list bug the module
    -- header says this file exists to exhibit (the test expects the
    -- result lists to be inconsistent), so it is deliberately left
    -- unchanged.
    process remaining res = do
      mitem <- atomicModifyCRef remaining pop
      case mitem of
        Just item -> do
          fwrap  <- item
          maybea <- result fwrap

          case maybea of
            Just a -> atomically $ do
              -- Prepend to an existing list of successes, or start a
              -- new list (also replacing a recorded failure).
              val <- tryTakeTMVar res
              case val of
                Just (Just as) -> putTMVar res $ Just (a:as)
                _ -> putTMVar res $ Just [a]
            Nothing -> process remaining res
        -- The queue is exhausted: record a failure unless a result
        -- has already been stored.
        Nothing -> failit res

    -- Take the next item off the queue, if there is one.
    pop []     = ([], Nothing)
    pop (x:xs) = (xs, Just x)

    -- Record that a computation failed. This is a no-op if the result
    -- variable is already full.
    failit res = atomically . void $ tryPutTMVar res Nothing