diff --git a/asterius/app/ahc-ld.hs b/asterius/app/ahc-ld.hs index e4f78d7d..effc89cd 100644 --- a/asterius/app/ahc-ld.hs +++ b/asterius/app/ahc-ld.hs @@ -2,6 +2,8 @@ {-# LANGUAGE ViewPatterns #-} import Asterius.Ld +import Control.Concurrent +import Control.Monad import Data.List import Data.Maybe import Data.String @@ -20,6 +22,10 @@ parseLinkTask args = do linkObjs = link_objs, linkLibs = link_libs, linkModule = mempty, + threadPoolSize = + maybe 1 read $ + find ("--thread-pool-size=" `isPrefixOf`) args + >>= stripPrefix "--thread-pool-size=", hasMain = "--no-main" `notElem` args, debug = "--debug" `elem` args, gcSections = "--no-gc-sections" `notElem` args, @@ -63,5 +69,6 @@ main = do rsp <- readFile rsp_path let rsp_args = map read $ lines rsp task <- parseLinkTask rsp_args + when (threadPoolSize task > 1) $ setNumCapabilities (threadPoolSize task) ignore <- isJust <$> getEnv "ASTERIUS_AHC_LD_IGNORE" if ignore then callProcess "touch" [linkOutput task] else linkExe task diff --git a/asterius/app/ahc-link.hs b/asterius/app/ahc-link.hs index 24a67fbf..d14d0111 100644 --- a/asterius/app/ahc-link.hs +++ b/asterius/app/ahc-link.hs @@ -1,4 +1,9 @@ import Asterius.Main +import Control.Concurrent +import Control.Monad main :: IO () -main = getTask >>= ahcLinkMain +main = do + task <- getTask + when (threadPoolSize task > 1) $ setNumCapabilities (threadPoolSize task) + ahcLinkMain task diff --git a/asterius/src/Asterius/Ar.hs b/asterius/src/Asterius/Ar.hs index b59d981b..0d9fbc60 100644 --- a/asterius/src/Asterius/Ar.hs +++ b/asterius/src/Asterius/Ar.hs @@ -4,22 +4,36 @@ module Asterius.Ar ( loadAr, + loadArchiveEntries, + loadArchiveEntry ) where import qualified Ar as GHC import Asterius.Binary.ByteString import Asterius.Types -import Data.Foldable import qualified IfaceEnv as GHC +-- | Load an archive file from disk, deserialize all objects it contains and +-- concatenate them into a single 'AsteriusCachedModule'. loadAr :: GHC.NameCacheUpdater -> FilePath -> IO AsteriusCachedModule -loadAr ncu p = do +loadAr ncu p = do -- TODO: This sequential version is currently being used by + -- Asterius.GHCi.Internals.asteriusIservCall + entries <- loadArchiveEntries p + mconcat <$> mapM (loadArchiveEntry ncu) entries + +-- | Load all the archive entries from an archive file @.a@, as a list of plain +-- 'ByteString's (content only). +{-# INLINE loadArchiveEntries #-} +loadArchiveEntries :: FilePath -> IO [GHC.ArchiveEntry] +loadArchiveEntries p = do GHC.Archive entries <- GHC.loadAr p - foldlM - ( \acc GHC.ArchiveEntry {..} -> tryGetBS ncu filedata >>= \case - Left _ -> pure acc - Right m -> pure $ m <> acc - ) - mempty - entries + return entries + +-- | Deserialize an 'GHC.ArchiveEntry'. In case deserialization fails, return +-- an empty 'AsteriusModule'. +loadArchiveEntry :: GHC.NameCacheUpdater -> GHC.ArchiveEntry -> IO AsteriusCachedModule +loadArchiveEntry ncu = \GHC.ArchiveEntry {..} -> + tryGetBS ncu filedata >>= \case + Left {} -> pure mempty + Right m -> pure m diff --git a/asterius/src/Asterius/Backends/Binaryen.hs b/asterius/src/Asterius/Backends/Binaryen.hs index 53f80c20..903fe95a 100644 --- a/asterius/src/Asterius/Backends/Binaryen.hs +++ b/asterius/src/Asterius/Backends/Binaryen.hs @@ -26,6 +26,7 @@ where import Asterius.Internals.Barf import Asterius.Internals.MagicNumber import Asterius.Internals.Marshal +import Asterius.Internals.Parallel import Asterius.Types import qualified Asterius.Types.SymbolMap as SM import Asterius.TypesConv @@ -525,24 +526,22 @@ marshalFunctionTable m tbl_slots FunctionTable {..} = flip runContT pure $ do (fromIntegral fnl) o +-- | Marshal the memory segments of a 'Module'. NOTE: It would be nice to +-- parallelize this process (see issue #621), but given that we want marshaling +-- to happen in @ContT@ for efficiency reasons, this might backfire. Leaving +-- linear for now. marshalMemorySegments :: Int -> [DataSegment] -> CodeGen () marshalMemorySegments mbs segs = do env <- ask m <- askModuleRef let segs_len = length segs + marshalOffset = \DataSegment {..} -> + lift $ flip runReaderT env $ marshalExpression $ ConstI32 offset lift $ flip runContT pure $ do (seg_bufs, _) <- marshalV =<< for segs (marshalBS . content) (seg_passives, _) <- marshalV $ replicate segs_len 0 - (seg_offsets, _) <- - marshalV - =<< for - segs - ( \DataSegment {..} -> - lift $ flip runReaderT env $ marshalExpression $ ConstI32 offset - ) - (seg_sizes, _) <- - marshalV $ - map (fromIntegral . BS.length . content) segs + (seg_offsets, _) <- marshalV =<< for segs marshalOffset + (seg_sizes, _) <- marshalV $ map (fromIntegral . BS.length . content) segs lift $ Binaryen.setMemory m @@ -571,8 +570,8 @@ marshalMemoryImport m MemoryImport {..} = flip runContT pure $ do lift $ Binaryen.addMemoryImport m inp emp ebp 0 marshalModule :: - Bool -> SM.SymbolMap Int64 -> Module -> IO Binaryen.Module -marshalModule tail_calls sym_map hs_mod@Module {..} = do + Bool -> Int -> SM.SymbolMap Int64 -> Module -> IO Binaryen.Module +marshalModule tail_calls pool_size sym_map hs_mod@Module {..} = do let fts = generateWasmFunctionTypeSet hs_mod m <- Binaryen.Module.create Binaryen.setFeatures m @@ -588,8 +587,8 @@ marshalModule tail_calls sym_map hs_mod@Module {..} = do envSymbolMap = sym_map, envModuleRef = m } - for_ (M.toList functionMap') $ \(k, f@Function {..}) -> - flip runReaderT env $ marshalFunction k (ftps M.! functionType) f + parallelFoldMap pool_size (M.toList functionMap') $ \(k, f@Function {..}) -> + flip runReaderT env $ void $ marshalFunction k (ftps M.! functionType) f forM_ functionImports $ \fi@FunctionImport {..} -> marshalFunctionImport m (ftps M.! functionType) fi forM_ functionExports $ marshalFunctionExport m diff --git a/asterius/src/Asterius/GHCi/Internals.hs b/asterius/src/Asterius/GHCi/Internals.hs index d0626c51..bb8a370a 100644 --- a/asterius/src/Asterius/GHCi/Internals.hs +++ b/asterius/src/Asterius/GHCi/Internals.hs @@ -295,6 +295,7 @@ asteriusWriteIServ hsc_env i a debug = False, gcSections = True, verboseErr = True, + threadPoolSize = 1, outputIR = Nothing, rootSymbols = [ run_q_exp_sym, diff --git a/asterius/src/Asterius/Internals/Parallel.hs b/asterius/src/Asterius/Internals/Parallel.hs new file mode 100644 index 00000000..e9735784 --- /dev/null +++ b/asterius/src/Asterius/Internals/Parallel.hs @@ -0,0 +1,60 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ViewPatterns #-} + +-- | +-- Module : Asterius.Internals.Parallel +-- Copyright : (c) 2018 EURL Tweag +-- License : All rights reserved (see LICENCE file in the distribution). +-- +-- Simple parallel combinators. Since we need to control our dependency +-- surface, our current approach to parallelism is very simple: given the +-- worker thread pool capacity @c@ and the list of tasks to be performed, +-- 'parallelFoldMap' pins exactly @c@ threads on each of the capabilities, lets +-- them consume the input concurrently, and gathers the results using their +-- 'Monoid' instance. Notice that this behavior is deterministic only if '<>' +-- is also symmetric (not only associative), but that is sufficient for our +-- usecases. +-- +-- To avoid needless threading overhead, if @c = 1@ them we fall back to the +-- sequential implementation. +module Asterius.Internals.Parallel + ( parallelRnf, + parallelFoldMap, + ) +where + +import Control.Concurrent +import Control.Concurrent.MVar +import Control.DeepSeq +import Control.Exception +import Control.Monad +import Data.IORef +import System.IO.Unsafe + +-- | Given the worker thread pool capacity @c@, @parallelRnf c xs@ deeply +-- evaluates a list of objects in parallel on the global thread pool. +parallelRnf :: NFData a => Int -> [a] -> () +parallelRnf n xs + | n >= 2 = unsafePerformIO $ parallelFoldMap n xs (void . evaluate . force) + | otherwise = rnf xs + +-- | Given the worker thread pool capacity @c@, @parallelFoldMap c xs f@ maps @f@ +-- on @xs@ in parallel on the global thread pool, and concatenates the results. +parallelFoldMap :: (NFData r, Monoid r) => Int -> [a] -> (a -> IO r) -> IO r +parallelFoldMap n xs fn + | n >= 2 = do + input <- newIORef xs + mvars <- replicateM n newEmptyMVar + let getNextElem = atomicModifyIORef' input $ \case + [] -> ([], Nothing) + (y : ys) -> (ys, Just y) + loop mvar !acc = getNextElem >>= \case -- was (force -> !acc) + Nothing -> putMVar mvar acc + Just y -> do + !res <- fn y -- was: res <- fn y -- was (force -> !res) <- fn y + loop mvar (acc <> res) + forM_ ([0 ..] `zip` mvars) $ \(i, mvar) -> + forkOn i (loop mvar mempty) + mconcat <$> forM mvars takeMVar + | otherwise = mconcat <$> mapM fn xs diff --git a/asterius/src/Asterius/JSRun/NonMain.hs b/asterius/src/Asterius/JSRun/NonMain.hs index a0bedd27..a73d1b9c 100644 --- a/asterius/src/Asterius/JSRun/NonMain.hs +++ b/asterius/src/Asterius/JSRun/NonMain.hs @@ -38,6 +38,7 @@ linkNonMain store_m extra_syms = (m, link_report) Asterius.Ld.debug = False, Asterius.Ld.gcSections = True, Asterius.Ld.verboseErr = True, + Asterius.Ld.threadPoolSize = 1, Asterius.Ld.outputIR = Nothing, rootSymbols = extra_syms, Asterius.Ld.exportFunctions = [] @@ -60,6 +61,7 @@ distNonMain p extra_syms = yolo = True, Asterius.Main.Task.hasMain = False, Asterius.Main.Task.verboseErr = True, + Asterius.Main.Task.threadPoolSize = 1, extraRootSymbols = extra_syms } diff --git a/asterius/src/Asterius/Ld.hs b/asterius/src/Asterius/Ld.hs index c35d9f85..7a3cfd08 100644 --- a/asterius/src/Asterius/Ld.hs +++ b/asterius/src/Asterius/Ld.hs @@ -1,4 +1,5 @@ {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE RecordWildCards #-} @@ -18,11 +19,11 @@ import Asterius.Binary.File import Asterius.Binary.NameCache import Asterius.Builtins import Asterius.Builtins.Main +import Asterius.Internals.Parallel import Asterius.Resolve import Asterius.Types import qualified Asterius.Types.SymbolSet as SS import Control.Exception -import Data.Either import Data.Traversable data LinkTask @@ -31,6 +32,7 @@ data LinkTask linkObjs, linkLibs :: [FilePath], linkModule :: AsteriusCachedModule, hasMain, debug, gcSections, verboseErr :: Bool, + threadPoolSize :: Int, outputIR :: Maybe FilePath, rootSymbols, exportFunctions :: [EntitySymbol] } @@ -45,9 +47,15 @@ data LinkTask loadTheWorld :: LinkTask -> IO AsteriusCachedModule loadTheWorld LinkTask {..} = do ncu <- newNameCacheUpdater - lib <- mconcat <$> for linkLibs (loadAr ncu) - objs <- rights <$> for linkObjs (tryGetFile ncu) - evaluate $ linkModule <> mconcat objs <> lib + lib <- do + entries <- concat <$> for linkLibs loadArchiveEntries + parallelFoldMap threadPoolSize entries (loadArchiveEntry ncu) + objs <- parallelFoldMap threadPoolSize linkObjs (loadObj ncu) + evaluate $ linkModule <> objs <> lib + where + loadObj ncu path = tryGetFile ncu path >>= \case + Left {} -> pure mempty + Right m -> pure m -- | The *_info are generated from Cmm using the INFO_TABLE macro. -- For example, see StgMiscClosures.cmm / Exception.cmm @@ -95,6 +103,7 @@ linkModules LinkTask {..} m = debug gcSections verboseErr + threadPoolSize ( toCachedModule ( (if hasMain then mainBuiltins else mempty) <> rtsAsteriusModule diff --git a/asterius/src/Asterius/Main.hs b/asterius/src/Asterius/Main.hs index e5b6e2d7..fd40ab4a 100644 --- a/asterius/src/Asterius/Main.hs +++ b/asterius/src/Asterius/Main.hs @@ -104,6 +104,11 @@ parseTask args = case err_msgs of in if i >= 0 && i <= 2 then t {shrinkLevel = i} else error "Shrink level must be [0..2]", + str_opt "thread-pool-size" $ \s t -> + let i = read s + in if i >= 1 + then t {threadPoolSize = i} + else error "Thread pool size must be positive", bool_opt "debug" $ \t -> t @@ -272,6 +277,7 @@ ahcLink task = do ] <> ["-optl--no-gc-sections" | not (gcSections task)] <> ["-optl--verbose-err" | verboseErr task] + <> ["-optl--thread-pool-size=" <> show (threadPoolSize task)] <> extraGHCFlags task <> [ "-optl--output-ir=" <> outputDirectory task @@ -313,6 +319,7 @@ ahcDistMain logger task (final_m, report) = do m_ref <- Binaryen.marshalModule (tailCalls task) + (threadPoolSize task) (staticsSymbolMap report <> functionSymbolMap report) final_m when (optimizeLevel task > 0 || shrinkLevel task > 0) $ do diff --git a/asterius/src/Asterius/Main/Task.hs b/asterius/src/Asterius/Main/Task.hs index c8e4e8c3..54d5b070 100644 --- a/asterius/src/Asterius/Main/Task.hs +++ b/asterius/src/Asterius/Main/Task.hs @@ -23,6 +23,7 @@ module Asterius.Main.Task verboseErr, yolo, consoleHistory, + threadPoolSize, extraGHCFlags, exportFunctions, extraRootSymbols, @@ -53,6 +54,7 @@ data Task outputDirectory :: FilePath, outputBaseName :: String, hasMain, validate, tailCalls, gcSections, bundle, debug, outputIR, run, verboseErr, yolo, consoleHistory :: Bool, + threadPoolSize :: Int, extraGHCFlags :: [String], exportFunctions, extraRootSymbols :: [EntitySymbol], gcThreshold :: Int @@ -79,6 +81,7 @@ defTask = Task verboseErr = False, yolo = False, consoleHistory = False, + threadPoolSize = 1, extraGHCFlags = [], exportFunctions = [], extraRootSymbols = [], diff --git a/asterius/src/Asterius/Resolve.hs b/asterius/src/Asterius/Resolve.hs index 8a825c59..99bf01e6 100644 --- a/asterius/src/Asterius/Resolve.hs +++ b/asterius/src/Asterius/Resolve.hs @@ -105,11 +105,12 @@ linkStart :: Bool -> Bool -> Bool -> + Int -> AsteriusCachedModule -> SS.SymbolSet -> [EntitySymbol] -> (AsteriusModule, Module, LinkReport) -linkStart debug gc_sections verbose_err store root_syms export_funcs = +linkStart debug gc_sections verbose_err pool_size store root_syms export_funcs = ( merged_m, result_m, mempty @@ -126,7 +127,7 @@ linkStart debug gc_sections verbose_err store root_syms export_funcs = merged_m0 | gc_sections = gcSections verbose_err store root_syms export_funcs | otherwise = fromCachedModule store - !merged_m0_evaluated = force merged_m0 + !merged_m0_evaluated = parForceAsteriusModule pool_size merged_m0 merged_m1 | debug = addMemoryTrap merged_m0_evaluated | otherwise = merged_m0_evaluated diff --git a/asterius/src/Asterius/Types.hs b/asterius/src/Asterius/Types.hs index 1463f8cf..f962d56c 100644 --- a/asterius/src/Asterius/Types.hs +++ b/asterius/src/Asterius/Types.hs @@ -16,7 +16,8 @@ module Asterius.Types AsteriusStaticsType (..), AsteriusStatics (..), AsteriusModule (..), - AsteriusCachedModule(..), + parForceAsteriusModule, + AsteriusCachedModule (..), toCachedModule, EntitySymbol, entityName, @@ -53,6 +54,7 @@ where import Asterius.Binary.Orphans () import Asterius.Binary.TH +import Asterius.Internals.Parallel import Asterius.NFData.TH import Asterius.Types.EntitySymbol import Asterius.Types.SymbolMap (SymbolMap) @@ -126,6 +128,21 @@ instance Semigroup AsteriusModule where instance Monoid AsteriusModule where mempty = AsteriusModule mempty mempty mempty mempty mempty +-- | Given the worker thread pool capacity @n@, @parForceAsteriusModule n m@ +-- deeply evaluates an 'AsteriusModule' m in parallel on the global thread +-- pool. To avoid needless threading overhead, if @n = 1@ them we fall back to +-- the sequential implementation. +parForceAsteriusModule :: Int -> AsteriusModule -> AsteriusModule +parForceAsteriusModule n m@(AsteriusModule sm se fm spt mod_ffi_state) + | n >= 2 = + parallelRnf n (SM.toList sm) + `seq` parallelRnf n (SM.toList se) + `seq` parallelRnf n (SM.toList fm) + `seq` parallelRnf n (SM.toList spt) + `seq` rnf mod_ffi_state + `seq` m + | otherwise = force m + -- | An 'AsteriusCachedModule' in an 'AsteriusModule' along with with all of -- its 'EntitySymbol' dependencies, as they are appear in the modules data -- segments and function definitions (see function 'toCachedModule'). @@ -158,7 +175,6 @@ toCachedModule m = where add :: Data a => SymbolMap a -> SymbolMap SymbolSet -> SymbolMap SymbolSet add = flip $ SM.foldrWithKey' (\k e -> SM.insert k (collectEntitySymbols e)) - -- Collect all entity symbols from an entity. collectEntitySymbols :: Data a => a -> SymbolSet collectEntitySymbols t