mirror of
https://github.com/tweag/asterius.git
synced 2024-10-26 09:21:04 +03:00
WIP (wip-parallel-experiment, squash)
This commit is contained in:
parent
af8575d5bb
commit
cc6d042884
@ -2,6 +2,8 @@
|
||||
{-# LANGUAGE ViewPatterns #-}
|
||||
|
||||
import Asterius.Ld
|
||||
import Control.Concurrent
|
||||
import Control.Monad
|
||||
import Data.List
|
||||
import Data.Maybe
|
||||
import Data.String
|
||||
@ -20,6 +22,10 @@ parseLinkTask args = do
|
||||
linkObjs = link_objs,
|
||||
linkLibs = link_libs,
|
||||
linkModule = mempty,
|
||||
threadPoolSize =
|
||||
maybe 1 read $
|
||||
find ("--thread-pool-size=" `isPrefixOf`) args
|
||||
>>= stripPrefix "--thread-pool-size=",
|
||||
hasMain = "--no-main" `notElem` args,
|
||||
debug = "--debug" `elem` args,
|
||||
gcSections = "--no-gc-sections" `notElem` args,
|
||||
@ -63,5 +69,6 @@ main = do
|
||||
rsp <- readFile rsp_path
|
||||
let rsp_args = map read $ lines rsp
|
||||
task <- parseLinkTask rsp_args
|
||||
when (threadPoolSize task > 1) $ setNumCapabilities (threadPoolSize task)
|
||||
ignore <- isJust <$> getEnv "ASTERIUS_AHC_LD_IGNORE"
|
||||
if ignore then callProcess "touch" [linkOutput task] else linkExe task
|
||||
|
@ -1,4 +1,9 @@
|
||||
import Asterius.Main
|
||||
import Control.Concurrent
|
||||
import Control.Monad
|
||||
|
||||
main :: IO ()
|
||||
main = getTask >>= ahcLinkMain
|
||||
main = do
|
||||
task <- getTask
|
||||
when (threadPoolSize task > 1) $ setNumCapabilities (threadPoolSize task)
|
||||
ahcLinkMain task
|
||||
|
@ -4,22 +4,36 @@
|
||||
|
||||
module Asterius.Ar
|
||||
( loadAr,
|
||||
loadArchiveEntries,
|
||||
loadArchiveEntry
|
||||
)
|
||||
where
|
||||
|
||||
import qualified Ar as GHC
|
||||
import Asterius.Binary.ByteString
|
||||
import Asterius.Types
|
||||
import Data.Foldable
|
||||
import qualified IfaceEnv as GHC
|
||||
|
||||
-- | Load an archive file from disk, deserialize all objects it contains and
|
||||
-- concatenate them into a single 'AsteriusCachedModule'.
|
||||
loadAr :: GHC.NameCacheUpdater -> FilePath -> IO AsteriusCachedModule
|
||||
loadAr ncu p = do
|
||||
loadAr ncu p = do -- TODO: This sequential version is currently being used by
|
||||
-- Asterius.GHCi.Internals.asteriusIservCall
|
||||
entries <- loadArchiveEntries p
|
||||
mconcat <$> mapM (loadArchiveEntry ncu) entries
|
||||
|
||||
-- | Load all the archive entries from an archive file @.a@, as a list of plain
|
||||
-- 'ByteString's (content only).
|
||||
{-# INLINE loadArchiveEntries #-}
|
||||
loadArchiveEntries :: FilePath -> IO [GHC.ArchiveEntry]
|
||||
loadArchiveEntries p = do
|
||||
GHC.Archive entries <- GHC.loadAr p
|
||||
foldlM
|
||||
( \acc GHC.ArchiveEntry {..} -> tryGetBS ncu filedata >>= \case
|
||||
Left _ -> pure acc
|
||||
Right m -> pure $ m <> acc
|
||||
)
|
||||
mempty
|
||||
entries
|
||||
return entries
|
||||
|
||||
-- | Deserialize an 'GHC.ArchiveEntry'. In case deserialization fails, return
|
||||
-- an empty 'AsteriusModule'.
|
||||
loadArchiveEntry :: GHC.NameCacheUpdater -> GHC.ArchiveEntry -> IO AsteriusCachedModule
|
||||
loadArchiveEntry ncu = \GHC.ArchiveEntry {..} ->
|
||||
tryGetBS ncu filedata >>= \case
|
||||
Left {} -> pure mempty
|
||||
Right m -> pure m
|
||||
|
@ -26,6 +26,7 @@ where
|
||||
import Asterius.Internals.Barf
|
||||
import Asterius.Internals.MagicNumber
|
||||
import Asterius.Internals.Marshal
|
||||
import Asterius.Internals.Parallel
|
||||
import Asterius.Types
|
||||
import qualified Asterius.Types.SymbolMap as SM
|
||||
import Asterius.TypesConv
|
||||
@ -525,24 +526,22 @@ marshalFunctionTable m tbl_slots FunctionTable {..} = flip runContT pure $ do
|
||||
(fromIntegral fnl)
|
||||
o
|
||||
|
||||
-- | Marshal the memory segments of a 'Module'. NOTE: It would be nice to
|
||||
-- parallelize this process (see issue #621), but given that we want marshaling
|
||||
-- to happen in @ContT@ for efficiency reasons, this might backfire. Leaving
|
||||
-- linear for now.
|
||||
marshalMemorySegments :: Int -> [DataSegment] -> CodeGen ()
|
||||
marshalMemorySegments mbs segs = do
|
||||
env <- ask
|
||||
m <- askModuleRef
|
||||
let segs_len = length segs
|
||||
marshalOffset = \DataSegment {..} ->
|
||||
lift $ flip runReaderT env $ marshalExpression $ ConstI32 offset
|
||||
lift $ flip runContT pure $ do
|
||||
(seg_bufs, _) <- marshalV =<< for segs (marshalBS . content)
|
||||
(seg_passives, _) <- marshalV $ replicate segs_len 0
|
||||
(seg_offsets, _) <-
|
||||
marshalV
|
||||
=<< for
|
||||
segs
|
||||
( \DataSegment {..} ->
|
||||
lift $ flip runReaderT env $ marshalExpression $ ConstI32 offset
|
||||
)
|
||||
(seg_sizes, _) <-
|
||||
marshalV $
|
||||
map (fromIntegral . BS.length . content) segs
|
||||
(seg_offsets, _) <- marshalV =<< for segs marshalOffset
|
||||
(seg_sizes, _) <- marshalV $ map (fromIntegral . BS.length . content) segs
|
||||
lift $
|
||||
Binaryen.setMemory
|
||||
m
|
||||
@ -571,8 +570,8 @@ marshalMemoryImport m MemoryImport {..} = flip runContT pure $ do
|
||||
lift $ Binaryen.addMemoryImport m inp emp ebp 0
|
||||
|
||||
marshalModule ::
|
||||
Bool -> SM.SymbolMap Int64 -> Module -> IO Binaryen.Module
|
||||
marshalModule tail_calls sym_map hs_mod@Module {..} = do
|
||||
Bool -> Int -> SM.SymbolMap Int64 -> Module -> IO Binaryen.Module
|
||||
marshalModule tail_calls pool_size sym_map hs_mod@Module {..} = do
|
||||
let fts = generateWasmFunctionTypeSet hs_mod
|
||||
m <- Binaryen.Module.create
|
||||
Binaryen.setFeatures m
|
||||
@ -588,8 +587,8 @@ marshalModule tail_calls sym_map hs_mod@Module {..} = do
|
||||
envSymbolMap = sym_map,
|
||||
envModuleRef = m
|
||||
}
|
||||
for_ (M.toList functionMap') $ \(k, f@Function {..}) ->
|
||||
flip runReaderT env $ marshalFunction k (ftps M.! functionType) f
|
||||
parallelFoldMap pool_size (M.toList functionMap') $ \(k, f@Function {..}) ->
|
||||
flip runReaderT env $ void $ marshalFunction k (ftps M.! functionType) f
|
||||
forM_ functionImports $ \fi@FunctionImport {..} ->
|
||||
marshalFunctionImport m (ftps M.! functionType) fi
|
||||
forM_ functionExports $ marshalFunctionExport m
|
||||
|
@ -295,6 +295,7 @@ asteriusWriteIServ hsc_env i a
|
||||
debug = False,
|
||||
gcSections = True,
|
||||
verboseErr = True,
|
||||
threadPoolSize = 1,
|
||||
outputIR = Nothing,
|
||||
rootSymbols =
|
||||
[ run_q_exp_sym,
|
||||
|
60
asterius/src/Asterius/Internals/Parallel.hs
Normal file
60
asterius/src/Asterius/Internals/Parallel.hs
Normal file
@ -0,0 +1,60 @@
|
||||
{-# LANGUAGE BangPatterns #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE ViewPatterns #-}
|
||||
|
||||
-- |
|
||||
-- Module : Asterius.Internals.Parallel
|
||||
-- Copyright : (c) 2018 EURL Tweag
|
||||
-- License : All rights reserved (see LICENCE file in the distribution).
|
||||
--
|
||||
-- Simple parallel combinators. Since we need to control our dependency
|
||||
-- surface, our current approach to parallelism is very simple: given the
|
||||
-- worker thread pool capacity @c@ and the list of tasks to be performed,
|
||||
-- 'parallelFoldMap' pins exactly @c@ threads on each of the capabilities, lets
|
||||
-- them consume the input concurrently, and gathers the results using their
|
||||
-- 'Monoid' instance. Notice that this behavior is deterministic only if '<>'
|
||||
-- is also symmetric (not only associative), but that is sufficient for our
|
||||
-- usecases.
|
||||
--
|
||||
-- To avoid needless threading overhead, if @c = 1@ them we fall back to the
|
||||
-- sequential implementation.
|
||||
module Asterius.Internals.Parallel
|
||||
( parallelRnf,
|
||||
parallelFoldMap,
|
||||
)
|
||||
where
|
||||
|
||||
import Control.Concurrent
|
||||
import Control.Concurrent.MVar
|
||||
import Control.DeepSeq
|
||||
import Control.Exception
|
||||
import Control.Monad
|
||||
import Data.IORef
|
||||
import System.IO.Unsafe
|
||||
|
||||
-- | Given the worker thread pool capacity @c@, @parallelRnf c xs@ deeply
|
||||
-- evaluates a list of objects in parallel on the global thread pool.
|
||||
parallelRnf :: NFData a => Int -> [a] -> ()
|
||||
parallelRnf n xs
|
||||
| n >= 2 = unsafePerformIO $ parallelFoldMap n xs (void . evaluate . force)
|
||||
| otherwise = rnf xs
|
||||
|
||||
-- | Given the worker thread pool capacity @c@, @parallelFoldMap c xs f@ maps @f@
|
||||
-- on @xs@ in parallel on the global thread pool, and concatenates the results.
|
||||
parallelFoldMap :: (NFData r, Monoid r) => Int -> [a] -> (a -> IO r) -> IO r
|
||||
parallelFoldMap n xs fn
|
||||
| n >= 2 = do
|
||||
input <- newIORef xs
|
||||
mvars <- replicateM n newEmptyMVar
|
||||
let getNextElem = atomicModifyIORef' input $ \case
|
||||
[] -> ([], Nothing)
|
||||
(y : ys) -> (ys, Just y)
|
||||
loop mvar !acc = getNextElem >>= \case -- was (force -> !acc)
|
||||
Nothing -> putMVar mvar acc
|
||||
Just y -> do
|
||||
!res <- fn y -- was: res <- fn y -- was (force -> !res) <- fn y
|
||||
loop mvar (acc <> res)
|
||||
forM_ ([0 ..] `zip` mvars) $ \(i, mvar) ->
|
||||
forkOn i (loop mvar mempty)
|
||||
mconcat <$> forM mvars takeMVar
|
||||
| otherwise = mconcat <$> mapM fn xs
|
@ -38,6 +38,7 @@ linkNonMain store_m extra_syms = (m, link_report)
|
||||
Asterius.Ld.debug = False,
|
||||
Asterius.Ld.gcSections = True,
|
||||
Asterius.Ld.verboseErr = True,
|
||||
Asterius.Ld.threadPoolSize = 1,
|
||||
Asterius.Ld.outputIR = Nothing,
|
||||
rootSymbols = extra_syms,
|
||||
Asterius.Ld.exportFunctions = []
|
||||
@ -60,6 +61,7 @@ distNonMain p extra_syms =
|
||||
yolo = True,
|
||||
Asterius.Main.Task.hasMain = False,
|
||||
Asterius.Main.Task.verboseErr = True,
|
||||
Asterius.Main.Task.threadPoolSize = 1,
|
||||
extraRootSymbols = extra_syms
|
||||
}
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE RecordWildCards #-}
|
||||
@ -18,11 +19,11 @@ import Asterius.Binary.File
|
||||
import Asterius.Binary.NameCache
|
||||
import Asterius.Builtins
|
||||
import Asterius.Builtins.Main
|
||||
import Asterius.Internals.Parallel
|
||||
import Asterius.Resolve
|
||||
import Asterius.Types
|
||||
import qualified Asterius.Types.SymbolSet as SS
|
||||
import Control.Exception
|
||||
import Data.Either
|
||||
import Data.Traversable
|
||||
|
||||
data LinkTask
|
||||
@ -31,6 +32,7 @@ data LinkTask
|
||||
linkObjs, linkLibs :: [FilePath],
|
||||
linkModule :: AsteriusCachedModule,
|
||||
hasMain, debug, gcSections, verboseErr :: Bool,
|
||||
threadPoolSize :: Int,
|
||||
outputIR :: Maybe FilePath,
|
||||
rootSymbols, exportFunctions :: [EntitySymbol]
|
||||
}
|
||||
@ -45,9 +47,15 @@ data LinkTask
|
||||
loadTheWorld :: LinkTask -> IO AsteriusCachedModule
|
||||
loadTheWorld LinkTask {..} = do
|
||||
ncu <- newNameCacheUpdater
|
||||
lib <- mconcat <$> for linkLibs (loadAr ncu)
|
||||
objs <- rights <$> for linkObjs (tryGetFile ncu)
|
||||
evaluate $ linkModule <> mconcat objs <> lib
|
||||
lib <- do
|
||||
entries <- concat <$> for linkLibs loadArchiveEntries
|
||||
parallelFoldMap threadPoolSize entries (loadArchiveEntry ncu)
|
||||
objs <- parallelFoldMap threadPoolSize linkObjs (loadObj ncu)
|
||||
evaluate $ linkModule <> objs <> lib
|
||||
where
|
||||
loadObj ncu path = tryGetFile ncu path >>= \case
|
||||
Left {} -> pure mempty
|
||||
Right m -> pure m
|
||||
|
||||
-- | The *_info are generated from Cmm using the INFO_TABLE macro.
|
||||
-- For example, see StgMiscClosures.cmm / Exception.cmm
|
||||
@ -95,6 +103,7 @@ linkModules LinkTask {..} m =
|
||||
debug
|
||||
gcSections
|
||||
verboseErr
|
||||
threadPoolSize
|
||||
( toCachedModule
|
||||
( (if hasMain then mainBuiltins else mempty)
|
||||
<> rtsAsteriusModule
|
||||
|
@ -104,6 +104,11 @@ parseTask args = case err_msgs of
|
||||
in if i >= 0 && i <= 2
|
||||
then t {shrinkLevel = i}
|
||||
else error "Shrink level must be [0..2]",
|
||||
str_opt "thread-pool-size" $ \s t ->
|
||||
let i = read s
|
||||
in if i >= 1
|
||||
then t {threadPoolSize = i}
|
||||
else error "Thread pool size must be positive",
|
||||
bool_opt "debug" $
|
||||
\t ->
|
||||
t
|
||||
@ -272,6 +277,7 @@ ahcLink task = do
|
||||
]
|
||||
<> ["-optl--no-gc-sections" | not (gcSections task)]
|
||||
<> ["-optl--verbose-err" | verboseErr task]
|
||||
<> ["-optl--thread-pool-size=" <> show (threadPoolSize task)]
|
||||
<> extraGHCFlags task
|
||||
<> [ "-optl--output-ir="
|
||||
<> outputDirectory task
|
||||
@ -313,6 +319,7 @@ ahcDistMain logger task (final_m, report) = do
|
||||
m_ref <-
|
||||
Binaryen.marshalModule
|
||||
(tailCalls task)
|
||||
(threadPoolSize task)
|
||||
(staticsSymbolMap report <> functionSymbolMap report)
|
||||
final_m
|
||||
when (optimizeLevel task > 0 || shrinkLevel task > 0) $ do
|
||||
|
@ -23,6 +23,7 @@ module Asterius.Main.Task
|
||||
verboseErr,
|
||||
yolo,
|
||||
consoleHistory,
|
||||
threadPoolSize,
|
||||
extraGHCFlags,
|
||||
exportFunctions,
|
||||
extraRootSymbols,
|
||||
@ -53,6 +54,7 @@ data Task
|
||||
outputDirectory :: FilePath,
|
||||
outputBaseName :: String,
|
||||
hasMain, validate, tailCalls, gcSections, bundle, debug, outputIR, run, verboseErr, yolo, consoleHistory :: Bool,
|
||||
threadPoolSize :: Int,
|
||||
extraGHCFlags :: [String],
|
||||
exportFunctions, extraRootSymbols :: [EntitySymbol],
|
||||
gcThreshold :: Int
|
||||
@ -79,6 +81,7 @@ defTask = Task
|
||||
verboseErr = False,
|
||||
yolo = False,
|
||||
consoleHistory = False,
|
||||
threadPoolSize = 1,
|
||||
extraGHCFlags = [],
|
||||
exportFunctions = [],
|
||||
extraRootSymbols = [],
|
||||
|
@ -105,11 +105,12 @@ linkStart ::
|
||||
Bool ->
|
||||
Bool ->
|
||||
Bool ->
|
||||
Int ->
|
||||
AsteriusCachedModule ->
|
||||
SS.SymbolSet ->
|
||||
[EntitySymbol] ->
|
||||
(AsteriusModule, Module, LinkReport)
|
||||
linkStart debug gc_sections verbose_err store root_syms export_funcs =
|
||||
linkStart debug gc_sections verbose_err pool_size store root_syms export_funcs =
|
||||
( merged_m,
|
||||
result_m,
|
||||
mempty
|
||||
@ -126,7 +127,7 @@ linkStart debug gc_sections verbose_err store root_syms export_funcs =
|
||||
merged_m0
|
||||
| gc_sections = gcSections verbose_err store root_syms export_funcs
|
||||
| otherwise = fromCachedModule store
|
||||
!merged_m0_evaluated = force merged_m0
|
||||
!merged_m0_evaluated = parForceAsteriusModule pool_size merged_m0
|
||||
merged_m1
|
||||
| debug = addMemoryTrap merged_m0_evaluated
|
||||
| otherwise = merged_m0_evaluated
|
||||
|
@ -16,7 +16,8 @@ module Asterius.Types
|
||||
AsteriusStaticsType (..),
|
||||
AsteriusStatics (..),
|
||||
AsteriusModule (..),
|
||||
AsteriusCachedModule(..),
|
||||
parForceAsteriusModule,
|
||||
AsteriusCachedModule (..),
|
||||
toCachedModule,
|
||||
EntitySymbol,
|
||||
entityName,
|
||||
@ -53,6 +54,7 @@ where
|
||||
|
||||
import Asterius.Binary.Orphans ()
|
||||
import Asterius.Binary.TH
|
||||
import Asterius.Internals.Parallel
|
||||
import Asterius.NFData.TH
|
||||
import Asterius.Types.EntitySymbol
|
||||
import Asterius.Types.SymbolMap (SymbolMap)
|
||||
@ -126,6 +128,21 @@ instance Semigroup AsteriusModule where
|
||||
instance Monoid AsteriusModule where
|
||||
mempty = AsteriusModule mempty mempty mempty mempty mempty
|
||||
|
||||
-- | Given the worker thread pool capacity @n@, @parForceAsteriusModule n m@
|
||||
-- deeply evaluates an 'AsteriusModule' m in parallel on the global thread
|
||||
-- pool. To avoid needless threading overhead, if @n = 1@ them we fall back to
|
||||
-- the sequential implementation.
|
||||
parForceAsteriusModule :: Int -> AsteriusModule -> AsteriusModule
|
||||
parForceAsteriusModule n m@(AsteriusModule sm se fm spt mod_ffi_state)
|
||||
| n >= 2 =
|
||||
parallelRnf n (SM.toList sm)
|
||||
`seq` parallelRnf n (SM.toList se)
|
||||
`seq` parallelRnf n (SM.toList fm)
|
||||
`seq` parallelRnf n (SM.toList spt)
|
||||
`seq` rnf mod_ffi_state
|
||||
`seq` m
|
||||
| otherwise = force m
|
||||
|
||||
-- | An 'AsteriusCachedModule' in an 'AsteriusModule' along with with all of
|
||||
-- its 'EntitySymbol' dependencies, as they are appear in the modules data
|
||||
-- segments and function definitions (see function 'toCachedModule').
|
||||
@ -158,7 +175,6 @@ toCachedModule m =
|
||||
where
|
||||
add :: Data a => SymbolMap a -> SymbolMap SymbolSet -> SymbolMap SymbolSet
|
||||
add = flip $ SM.foldrWithKey' (\k e -> SM.insert k (collectEntitySymbols e))
|
||||
|
||||
-- Collect all entity symbols from an entity.
|
||||
collectEntitySymbols :: Data a => a -> SymbolSet
|
||||
collectEntitySymbols t
|
||||
|
Loading…
Reference in New Issue
Block a user