mirror of
https://github.com/anoma/juvix.git
synced 2025-01-03 04:58:50 +03:00
Inlining (#2036)
* Closes #1989 * Adds optimization phases to the pipline (specified by `opt-phase-eval`, `opt-phase-exec` and `opt-phase-geb` transformations). * Adds the `-O` option to the `compile` command to specify the optimization level. * Functions can be declared for inlining with the `inline` pragma: ``` {-# inline: true #-} const : {A B : Type} -> A -> B -> A; const x _ := x; ``` By default, the function is inlined only if it's fully applied. One can specify that a function (partially) applied to at least `n` explicit arguments should be inlined. ``` {-# inline: 2 #-} compose : {A B C : Type} -> (B -> C) -> (A -> B) -> A -> C; compose f g x := f (g x); ``` Then `compose f g` will be inlined, even though it's not fully applied. But `compose f` won't be inlined. * Non-recursive fully applied functions are automatically inlined if the height of the body term does not exceed the inlining depth limit, which can be specified with the `--inline` option to the `compile` command. * The pragma `inline: false` disables automatic inlining on a per-function basis.
This commit is contained in:
parent
ebeef381e6
commit
8aa54ecc28
@ -23,7 +23,9 @@ getEntry PipelineArg {..} = do
|
||||
return $
|
||||
ep
|
||||
{ _entryPointTarget = getTarget (_pipelineArgOptions ^. compileTarget),
|
||||
_entryPointDebug = _pipelineArgOptions ^. compileDebug
|
||||
_entryPointDebug = _pipelineArgOptions ^. compileDebug,
|
||||
_entryPointOptimizationLevel = _pipelineArgOptions ^. compileOptimizationLevel,
|
||||
_entryPointInliningDepth = _pipelineArgOptions ^. compileInliningDepth
|
||||
}
|
||||
where
|
||||
getTarget :: CompileTarget -> Backend.Target
|
||||
|
@ -1,6 +1,7 @@
|
||||
module Commands.Extra.Compile.Options where
|
||||
|
||||
import CommonOptions hiding (show)
|
||||
import Juvix.Compiler.Pipeline.EntryPoint
|
||||
import Prelude (Show (show))
|
||||
|
||||
data CompileTarget
|
||||
@ -9,7 +10,7 @@ data CompileTarget
|
||||
| TargetGeb
|
||||
| TargetCore
|
||||
| TargetAsm
|
||||
deriving stock (Data, Bounded, Enum)
|
||||
deriving stock (Eq, Data, Bounded, Enum)
|
||||
|
||||
instance Show CompileTarget where
|
||||
show = \case
|
||||
@ -27,7 +28,9 @@ data CompileOptions = CompileOptions
|
||||
_compileTerm :: Bool,
|
||||
_compileOutputFile :: Maybe (AppPath File),
|
||||
_compileTarget :: CompileTarget,
|
||||
_compileInputFile :: AppPath File
|
||||
_compileInputFile :: AppPath File,
|
||||
_compileOptimizationLevel :: Int,
|
||||
_compileInliningDepth :: Int
|
||||
}
|
||||
deriving stock (Data)
|
||||
|
||||
@ -68,10 +71,29 @@ parseCompileOptions supportedTargets parseInputFile = do
|
||||
<> help "Produce assembly output only (for targets: wasm32-wasi, native)"
|
||||
)
|
||||
_compileTerm <-
|
||||
switch
|
||||
( short 'G'
|
||||
<> long "only-term"
|
||||
<> help "Produce term output only (for targets: geb)"
|
||||
if
|
||||
| elem TargetGeb supportedTargets ->
|
||||
switch
|
||||
( short 'G'
|
||||
<> long "only-term"
|
||||
<> help "Produce term output only (for targets: geb)"
|
||||
)
|
||||
| otherwise ->
|
||||
pure False
|
||||
_compileOptimizationLevel <-
|
||||
option
|
||||
(fromIntegral <$> naturalNumberOpt)
|
||||
( short 'O'
|
||||
<> long "optimize"
|
||||
<> value defaultOptimizationLevel
|
||||
<> help ("Optimization level (default: " <> show defaultOptimizationLevel <> ")")
|
||||
)
|
||||
_compileInliningDepth <-
|
||||
option
|
||||
(fromIntegral <$> naturalNumberOpt)
|
||||
( long "inline"
|
||||
<> value defaultInliningDepth
|
||||
<> help ("Automatic inlining depth limit, logarithmic in the function size (default: " <> show defaultInliningDepth <> ")")
|
||||
)
|
||||
_compileTarget <- optCompileTarget supportedTargets
|
||||
_compileOutputFile <- optional parseGenericOutputFile
|
||||
|
@ -50,7 +50,9 @@ instance CanonicalProjection GlobalOptions Core.CoreOptions where
|
||||
project GlobalOptions {..} =
|
||||
Core.CoreOptions
|
||||
{ Core._optCheckCoverage = not _globalNoCoverage,
|
||||
Core._optUnrollLimit = _globalUnrollLimit
|
||||
Core._optUnrollLimit = _globalUnrollLimit,
|
||||
Core._optOptimizationLevel = defaultOptimizationLevel,
|
||||
Core._optInliningDepth = defaultInliningDepth
|
||||
}
|
||||
|
||||
defaultGlobalOptions :: GlobalOptions
|
||||
|
@ -474,7 +474,7 @@ foldS sig code a = snd <$> foldS' sig initialStackInfo code a
|
||||
foldS' :: forall r a. (Member (Error AsmError) r) => FoldSig StackInfo r a -> StackInfo -> Code -> a -> Sem r (StackInfo, a)
|
||||
foldS' sig si code acc = do
|
||||
(si', fs) <- recurseS' sig' si code
|
||||
a' <- compose fs acc
|
||||
a' <- compose' fs acc
|
||||
return (si', a')
|
||||
where
|
||||
sig' :: RecursorSig StackInfo r (a -> Sem r a)
|
||||
@ -486,21 +486,21 @@ foldS' sig si code acc = do
|
||||
return
|
||||
( \a -> do
|
||||
let a' = (sig ^. foldAdjust) a
|
||||
a1 <- compose br1 a'
|
||||
a2 <- compose br2 a'
|
||||
a1 <- compose' br1 a'
|
||||
a2 <- compose' br2 a'
|
||||
(sig ^. foldBranch) s cmd a1 a2 a
|
||||
),
|
||||
_recurseCase = \s cmd brs md ->
|
||||
return
|
||||
( \a -> do
|
||||
let a' = (sig ^. foldAdjust) a
|
||||
as <- mapM (`compose` a') brs
|
||||
as <- mapM (`compose'` a') brs
|
||||
ad <- case md of
|
||||
Just d -> Just <$> compose d a'
|
||||
Just d -> Just <$> compose' d a'
|
||||
Nothing -> return Nothing
|
||||
(sig ^. foldCase) s cmd as ad a
|
||||
)
|
||||
}
|
||||
|
||||
compose :: [a -> Sem r a] -> a -> Sem r a
|
||||
compose lst x = foldr (=<<) (return x) lst
|
||||
compose' :: [a -> Sem r a] -> a -> Sem r a
|
||||
compose' lst x = foldr (=<<) (return x) lst
|
||||
|
@ -25,3 +25,6 @@ createIdentDependencyInfo tab = createDependencyInfo graph startVertices
|
||||
|
||||
syms :: [Symbol]
|
||||
syms = map (^. identifierSymbol) (HashMap.elems (tab ^. infoIdentifiers))
|
||||
|
||||
recursiveIdents :: InfoTable -> HashSet Symbol
|
||||
recursiveIdents = nodesOnCycles . createIdentDependencyInfo
|
||||
|
@ -21,7 +21,13 @@ data TransformationId
|
||||
| CheckGeb
|
||||
| CheckExec
|
||||
| LetFolding
|
||||
| LambdaFolding
|
||||
| Inlining
|
||||
| FoldTypeSynonyms
|
||||
| OptPhaseEval
|
||||
| OptPhaseExec
|
||||
| OptPhaseGeb
|
||||
| OptPhaseMain
|
||||
deriving stock (Data, Bounded, Enum, Show)
|
||||
|
||||
data PipelineId
|
||||
@ -52,14 +58,14 @@ toTypecheckTransformations :: [TransformationId]
|
||||
toTypecheckTransformations = [MatchToCase]
|
||||
|
||||
toEvalTransformations :: [TransformationId]
|
||||
toEvalTransformations = [EtaExpandApps, MatchToCase, NatToPrimInt, IntToPrimInt, ConvertBuiltinTypes, LetFolding]
|
||||
toEvalTransformations = [EtaExpandApps, MatchToCase, NatToPrimInt, IntToPrimInt, ConvertBuiltinTypes, OptPhaseEval]
|
||||
|
||||
toStrippedTransformations :: [TransformationId]
|
||||
toStrippedTransformations =
|
||||
toEvalTransformations ++ [CheckExec, LambdaLetRecLifting, LetFolding, TopEtaExpand, MoveApps, RemoveTypeArgs]
|
||||
toEvalTransformations ++ [CheckExec, LambdaLetRecLifting, OptPhaseExec, TopEtaExpand, MoveApps, RemoveTypeArgs]
|
||||
|
||||
toGebTransformations :: [TransformationId]
|
||||
toGebTransformations = toEvalTransformations ++ [CheckGeb, LetRecLifting, LetFolding, UnrollRecursion, FoldTypeSynonyms, ComputeTypeInfo]
|
||||
toGebTransformations = toEvalTransformations ++ [CheckGeb, LetRecLifting, OptPhaseGeb, UnrollRecursion, FoldTypeSynonyms, ComputeTypeInfo]
|
||||
|
||||
pipeline :: PipelineId -> [TransformationId]
|
||||
pipeline = \case
|
||||
|
@ -79,7 +79,13 @@ transformationText = \case
|
||||
CheckGeb -> strCheckGeb
|
||||
CheckExec -> strCheckExec
|
||||
LetFolding -> strLetFolding
|
||||
LambdaFolding -> strLambdaFolding
|
||||
Inlining -> strInlining
|
||||
FoldTypeSynonyms -> strFoldTypeSynonyms
|
||||
OptPhaseEval -> strOptPhaseEval
|
||||
OptPhaseExec -> strOptPhaseExec
|
||||
OptPhaseGeb -> strOptPhaseGeb
|
||||
OptPhaseMain -> strOptPhaseMain
|
||||
|
||||
parsePipeline :: MonadParsec e Text m => m PipelineId
|
||||
parsePipeline = P.choice [symbol (pipelineText t) $> t | t <- allElements]
|
||||
@ -153,5 +159,23 @@ strCheckExec = "check-exec"
|
||||
strLetFolding :: Text
|
||||
strLetFolding = "let-folding"
|
||||
|
||||
strLambdaFolding :: Text
|
||||
strLambdaFolding = "lambda-folding"
|
||||
|
||||
strInlining :: Text
|
||||
strInlining = "inlining"
|
||||
|
||||
strFoldTypeSynonyms :: Text
|
||||
strFoldTypeSynonyms = "fold-type-synonyms"
|
||||
|
||||
strOptPhaseEval :: Text
|
||||
strOptPhaseEval = "opt-phase-eval"
|
||||
|
||||
strOptPhaseExec :: Text
|
||||
strOptPhaseExec = "opt-phase-exec"
|
||||
|
||||
strOptPhaseGeb :: Text
|
||||
strOptPhaseGeb = "opt-phase-geb"
|
||||
|
||||
strOptPhaseMain :: Text
|
||||
strOptPhaseMain = "opt-phase-main"
|
||||
|
@ -60,7 +60,10 @@ isImmediate = \case
|
||||
NVar {} -> True
|
||||
NIdt {} -> True
|
||||
NCst {} -> True
|
||||
_ -> False
|
||||
node@(NApp {}) ->
|
||||
let (_, args) = unfoldApps' node
|
||||
in all isType args
|
||||
node -> isType node
|
||||
|
||||
freeVarsSortedMany :: [Node] -> Set Var
|
||||
freeVarsSortedMany n = Set.fromList (n ^.. each . freeVars)
|
||||
@ -274,3 +277,7 @@ builtinOpArgTypes = \case
|
||||
OpSeq -> [mkDynamic', mkDynamic']
|
||||
OpTrace -> [mkDynamic']
|
||||
OpFail -> [mkTypeString']
|
||||
|
||||
checkDepth :: Int -> Node -> Bool
|
||||
checkDepth 0 _ = False
|
||||
checkDepth d node = all (checkDepth (d - 1)) (childrenNodes node)
|
||||
|
@ -5,7 +5,9 @@ import Juvix.Prelude
|
||||
|
||||
data CoreOptions = CoreOptions
|
||||
{ _optCheckCoverage :: Bool,
|
||||
_optUnrollLimit :: Int
|
||||
_optUnrollLimit :: Int,
|
||||
_optOptimizationLevel :: Int,
|
||||
_optInliningDepth :: Int
|
||||
}
|
||||
|
||||
makeLenses ''CoreOptions
|
||||
@ -14,12 +16,16 @@ defaultCoreOptions :: CoreOptions
|
||||
defaultCoreOptions =
|
||||
CoreOptions
|
||||
{ _optCheckCoverage = True,
|
||||
_optUnrollLimit = defaultUnrollLimit
|
||||
_optUnrollLimit = defaultUnrollLimit,
|
||||
_optOptimizationLevel = defaultOptimizationLevel,
|
||||
_optInliningDepth = defaultInliningDepth
|
||||
}
|
||||
|
||||
fromEntryPoint :: EntryPoint -> CoreOptions
|
||||
fromEntryPoint EntryPoint {..} =
|
||||
CoreOptions
|
||||
{ _optCheckCoverage = not _entryPointNoCoverage,
|
||||
_optUnrollLimit = _entryPointUnrollLimit
|
||||
_optUnrollLimit = _entryPointUnrollLimit,
|
||||
_optOptimizationLevel = _entryPointOptimizationLevel,
|
||||
_optInliningDepth = _entryPointInliningDepth
|
||||
}
|
||||
|
@ -26,7 +26,13 @@ import Juvix.Compiler.Core.Transformation.MatchToCase
|
||||
import Juvix.Compiler.Core.Transformation.MoveApps
|
||||
import Juvix.Compiler.Core.Transformation.NaiveMatchToCase qualified as Naive
|
||||
import Juvix.Compiler.Core.Transformation.NatToPrimInt
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.Inlining
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.LetFolding
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.Phase.Eval qualified as Phase.Eval
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.Phase.Exec qualified as Phase.Exec
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.Phase.Geb qualified as Phase.Geb
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.Phase.Main qualified as Phase.Main
|
||||
import Juvix.Compiler.Core.Transformation.RemoveTypeArgs
|
||||
import Juvix.Compiler.Core.Transformation.TopEtaExpand
|
||||
import Juvix.Compiler.Core.Transformation.UnrollRecursion
|
||||
@ -54,4 +60,10 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts
|
||||
CheckGeb -> mapError (JuvixError @CoreError) . checkGeb
|
||||
CheckExec -> mapError (JuvixError @CoreError) . checkExec
|
||||
LetFolding -> return . letFolding
|
||||
LambdaFolding -> return . lambdaFolding
|
||||
Inlining -> inlining
|
||||
FoldTypeSynonyms -> return . foldTypeSynonyms
|
||||
OptPhaseEval -> Phase.Eval.optimize
|
||||
OptPhaseExec -> Phase.Exec.optimize
|
||||
OptPhaseGeb -> Phase.Geb.optimize
|
||||
OptPhaseMain -> Phase.Main.optimize
|
||||
|
@ -9,6 +9,7 @@ import Data.HashMap.Strict qualified as HashMap
|
||||
import Juvix.Compiler.Core.Data.InfoTable
|
||||
import Juvix.Compiler.Core.Data.InfoTableBuilder
|
||||
import Juvix.Compiler.Core.Language
|
||||
import Juvix.Compiler.Core.Options
|
||||
|
||||
mapIdentsM :: Monad m => (IdentifierInfo -> m IdentifierInfo) -> InfoTable -> m InfoTable
|
||||
mapIdentsM = overM infoIdentifiers . mapM
|
||||
@ -85,3 +86,13 @@ mapAllNodes f tab =
|
||||
|
||||
convertAxiom :: AxiomInfo -> AxiomInfo
|
||||
convertAxiom = over axiomType f
|
||||
|
||||
withOptimizationLevel :: Member (Reader CoreOptions) r => Int -> (InfoTable -> Sem r InfoTable) -> InfoTable -> Sem r InfoTable
|
||||
withOptimizationLevel n f tab = do
|
||||
l <- asks (^. optOptimizationLevel)
|
||||
if
|
||||
| l >= n -> f tab
|
||||
| otherwise -> return tab
|
||||
|
||||
withOptimizationLevel' :: Member (Reader CoreOptions) r => InfoTable -> Int -> (InfoTable -> Sem r InfoTable) -> Sem r InfoTable
|
||||
withOptimizationLevel' tab n f = withOptimizationLevel n f tab
|
||||
|
57
src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs
Normal file
57
src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs
Normal file
@ -0,0 +1,57 @@
|
||||
module Juvix.Compiler.Core.Transformation.Optimize.Inlining where
|
||||
|
||||
import Data.HashSet qualified as HashSet
|
||||
import Juvix.Compiler.Core.Data.IdentDependencyInfo
|
||||
import Juvix.Compiler.Core.Extra
|
||||
import Juvix.Compiler.Core.Options
|
||||
import Juvix.Compiler.Core.Transformation.Base
|
||||
|
||||
isInlineableLambda :: Int -> Node -> Bool
|
||||
isInlineableLambda inlineDepth node = case node of
|
||||
NLam {} ->
|
||||
checkDepth inlineDepth (snd (unfoldLambdas node))
|
||||
_ ->
|
||||
False
|
||||
|
||||
convertNode :: Int -> HashSet Symbol -> InfoTable -> Node -> Node
|
||||
convertNode inlineDepth recSyms tab = dmap go
|
||||
where
|
||||
go :: Node -> Node
|
||||
go node = case node of
|
||||
NApp {} ->
|
||||
let (h, args) = unfoldApps node
|
||||
in case h of
|
||||
NIdt Ident {..} ->
|
||||
case pi of
|
||||
Just InlineFullyApplied
|
||||
| length args >= argsNum ->
|
||||
mkApps def args
|
||||
Just (InlinePartiallyApplied k)
|
||||
| length args >= k ->
|
||||
mkApps def args
|
||||
Just InlineNever ->
|
||||
node
|
||||
_
|
||||
| not (HashSet.member _identSymbol recSyms)
|
||||
&& isInlineableLambda inlineDepth def
|
||||
&& length args >= argsNum ->
|
||||
mkApps def args
|
||||
_ ->
|
||||
node
|
||||
where
|
||||
ii = lookupIdentifierInfo tab _identSymbol
|
||||
pi = ii ^. identifierPragmas . pragmasInline
|
||||
argsNum = ii ^. identifierArgsNum
|
||||
def = lookupIdentifierNode tab _identSymbol
|
||||
_ ->
|
||||
node
|
||||
_ ->
|
||||
node
|
||||
|
||||
inlining' :: Int -> HashSet Symbol -> InfoTable -> InfoTable
|
||||
inlining' inliningDepth recSyms tab = mapT (const (convertNode inliningDepth recSyms tab)) tab
|
||||
|
||||
inlining :: Member (Reader CoreOptions) r => InfoTable -> Sem r InfoTable
|
||||
inlining tab = do
|
||||
d <- asks (^. optInliningDepth)
|
||||
return $ inlining' d (recursiveIdents tab) tab
|
@ -0,0 +1,45 @@
|
||||
-- An optimizing transformation that converts beta-redexes into let-expressions.
|
||||
--
|
||||
-- For example, transforms
|
||||
-- ```
|
||||
-- (\x \y x + y * x) a b
|
||||
-- ```
|
||||
-- to
|
||||
-- ```
|
||||
-- let x := a in let y := b in x + y * x
|
||||
-- ```
|
||||
module Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding where
|
||||
|
||||
import Juvix.Compiler.Core.Extra
|
||||
import Juvix.Compiler.Core.Transformation.Base
|
||||
|
||||
convertNode :: Node -> Node
|
||||
convertNode = rmap go
|
||||
where
|
||||
go :: ([BinderChange] -> Node -> Node) -> Node -> Node
|
||||
go recur node = case node of
|
||||
NApp {} ->
|
||||
let (h, args) = unfoldApps' node
|
||||
(lams, body) = unfoldLambdas h
|
||||
in goLams [] lams args body
|
||||
where
|
||||
goLams :: [BinderChange] -> [LambdaLhs] -> [Node] -> Node -> Node
|
||||
goLams bcs lams args body =
|
||||
case (lams, args) of
|
||||
([], _) ->
|
||||
mkApps'
|
||||
(go (recur . (revAppend bcs)) body)
|
||||
(map (go (recur . (BCAdd (length bcs) :))) args)
|
||||
(lam : lams', arg : args') ->
|
||||
mkLet mempty bd' (go (recur . (BCAdd (length bcs) :)) arg) $
|
||||
goLams (BCKeep bd : bcs) lams' args' body
|
||||
where
|
||||
bd = lam ^. lambdaLhsBinder
|
||||
bd' = over binderType (go (recur . (revAppend bcs))) bd
|
||||
(_, []) ->
|
||||
go (recur . (revAppend bcs)) (reLambdas lams body)
|
||||
_ ->
|
||||
recur [] node
|
||||
|
||||
lambdaFolding :: InfoTable -> InfoTable
|
||||
lambdaFolding = mapAllNodes convertNode
|
@ -15,18 +15,22 @@ module Juvix.Compiler.Core.Transformation.Optimize.LetFolding where
|
||||
import Juvix.Compiler.Core.Extra
|
||||
import Juvix.Compiler.Core.Transformation.Base
|
||||
|
||||
convertNode :: Node -> Node
|
||||
convertNode = rmap go
|
||||
convertNode :: (Node -> Bool) -> Node -> Node
|
||||
convertNode isFoldable = rmap go
|
||||
where
|
||||
go :: ([BinderChange] -> Node -> Node) -> Node -> Node
|
||||
go recur = \case
|
||||
NLet Let {..}
|
||||
| isImmediate (_letItem ^. letItemValue) ->
|
||||
| isImmediate (_letItem ^. letItemValue)
|
||||
|| isFoldable (_letItem ^. letItemValue) ->
|
||||
go (recur . (mkBCRemove (_letItem ^. letItemBinder) val' :)) _letBody
|
||||
where
|
||||
val' = go recur (_letItem ^. letItemValue)
|
||||
node ->
|
||||
recur [] node
|
||||
|
||||
letFolding' :: (Node -> Bool) -> InfoTable -> InfoTable
|
||||
letFolding' isFoldable = mapAllNodes (convertNode isFoldable)
|
||||
|
||||
letFolding :: InfoTable -> InfoTable
|
||||
letFolding = mapAllNodes convertNode
|
||||
letFolding = letFolding' (const False)
|
||||
|
@ -0,0 +1,11 @@
|
||||
module Juvix.Compiler.Core.Transformation.Optimize.Phase.Eval where
|
||||
|
||||
import Juvix.Compiler.Core.Options
|
||||
import Juvix.Compiler.Core.Transformation.Base
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.LetFolding
|
||||
|
||||
optimize :: Member (Reader CoreOptions) r => InfoTable -> Sem r InfoTable
|
||||
optimize =
|
||||
withOptimizationLevel 1 $
|
||||
return . letFolding . lambdaFolding
|
@ -0,0 +1,16 @@
|
||||
module Juvix.Compiler.Core.Transformation.Optimize.Phase.Exec where
|
||||
|
||||
import Juvix.Compiler.Core.Options
|
||||
import Juvix.Compiler.Core.Transformation.Base
|
||||
import Juvix.Compiler.Core.Transformation.LambdaLetRecLifting
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.LetFolding
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.Phase.Main qualified as Main
|
||||
|
||||
optimize :: Member (Reader CoreOptions) r => InfoTable -> Sem r InfoTable
|
||||
optimize tab = do
|
||||
opts <- ask
|
||||
withOptimizationLevel' tab 1 $
|
||||
return
|
||||
. letFolding
|
||||
. lambdaLetRecLifting
|
||||
. Main.optimize' opts
|
@ -0,0 +1,8 @@
|
||||
module Juvix.Compiler.Core.Transformation.Optimize.Phase.Geb where
|
||||
|
||||
import Juvix.Compiler.Core.Options
|
||||
import Juvix.Compiler.Core.Transformation.Base
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.Phase.Main qualified as Main
|
||||
|
||||
optimize :: Member (Reader CoreOptions) r => InfoTable -> Sem r InfoTable
|
||||
optimize = withOptimizationLevel 1 Main.optimize
|
@ -0,0 +1,24 @@
|
||||
module Juvix.Compiler.Core.Transformation.Optimize.Phase.Main where
|
||||
|
||||
import Juvix.Compiler.Core.Data.IdentDependencyInfo
|
||||
import Juvix.Compiler.Core.Options
|
||||
import Juvix.Compiler.Core.Transformation.Base
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.Inlining
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.LetFolding
|
||||
|
||||
optimize' :: CoreOptions -> InfoTable -> InfoTable
|
||||
optimize' CoreOptions {..} tab =
|
||||
compose
|
||||
(4 * _optOptimizationLevel)
|
||||
( compose 2 (letFolding' (isInlineableLambda _optInliningDepth))
|
||||
. lambdaFolding
|
||||
. inlining' _optInliningDepth (recursiveIdents tab)
|
||||
)
|
||||
. letFolding
|
||||
$ tab
|
||||
|
||||
optimize :: Member (Reader CoreOptions) r => InfoTable -> Sem r InfoTable
|
||||
optimize tab = do
|
||||
opts <- ask
|
||||
return $ optimize' opts tab
|
@ -256,7 +256,11 @@ preFunctionDef f = do
|
||||
_identifierArgsNum = 0,
|
||||
_identifierIsExported = False,
|
||||
_identifierBuiltin = f ^. Internal.funDefBuiltin,
|
||||
_identifierPragmas = f ^. Internal.funDefPragmas
|
||||
_identifierPragmas =
|
||||
over
|
||||
pragmasInline
|
||||
(fmap (adjustPragmaInline (implicitParametersNum (f ^. Internal.funDefType))))
|
||||
(f ^. Internal.funDefPragmas)
|
||||
}
|
||||
case f ^. Internal.funDefBuiltin of
|
||||
Just b
|
||||
@ -279,6 +283,18 @@ preFunctionDef f = do
|
||||
">=" -> Str.natGe
|
||||
_ -> name
|
||||
|
||||
implicitParametersNum :: Internal.Expression -> Int
|
||||
implicitParametersNum = \case
|
||||
Internal.ExpressionFunction Internal.Function {..}
|
||||
| _functionLeft ^. Internal.paramImplicit == Implicit ->
|
||||
implicitParametersNum _functionRight + 1
|
||||
_ -> 0
|
||||
|
||||
adjustPragmaInline :: Int -> PragmaInline -> PragmaInline
|
||||
adjustPragmaInline n = \case
|
||||
InlinePartiallyApplied k -> InlinePartiallyApplied (k + n)
|
||||
x -> x
|
||||
|
||||
goFunctionDef ::
|
||||
forall r.
|
||||
(Members '[InfoTableBuilder, Reader Internal.InfoTable, Reader InternalTyped.TypesTable, State InternalTyped.FunctionsTable] r) =>
|
||||
|
@ -27,6 +27,8 @@ data EntryPoint = EntryPoint
|
||||
_entryPointTarget :: Target,
|
||||
_entryPointDebug :: Bool,
|
||||
_entryPointUnrollLimit :: Int,
|
||||
_entryPointOptimizationLevel :: Int,
|
||||
_entryPointInliningDepth :: Int,
|
||||
_entryPointGenericOptions :: GenericOptions,
|
||||
_entryPointModulePaths :: [Path Abs File]
|
||||
}
|
||||
@ -69,11 +71,19 @@ defaultEntryPointNoFile roots =
|
||||
_entryPointTarget = TargetCore,
|
||||
_entryPointDebug = False,
|
||||
_entryPointUnrollLimit = defaultUnrollLimit,
|
||||
_entryPointOptimizationLevel = defaultOptimizationLevel,
|
||||
_entryPointInliningDepth = defaultInliningDepth,
|
||||
_entryPointModulePaths = []
|
||||
}
|
||||
|
||||
defaultUnrollLimit :: Int
|
||||
defaultUnrollLimit = 140
|
||||
|
||||
defaultOptimizationLevel :: Int
|
||||
defaultOptimizationLevel = 1
|
||||
|
||||
defaultInliningDepth :: Int
|
||||
defaultInliningDepth = 2
|
||||
|
||||
mainModulePath :: Traversal' EntryPoint (Path Abs File)
|
||||
mainModulePath = entryPointModulePaths . _head
|
||||
|
@ -65,3 +65,11 @@ buildSCCs = Graph.stronglyConnComp . (^. depInfoEdgeList)
|
||||
|
||||
isCyclic :: Ord n => DependencyInfo n -> Bool
|
||||
isCyclic = any (\case CyclicSCC _ -> True; _ -> False) . buildSCCs
|
||||
|
||||
nodesOnCycles :: forall n. (Hashable n, Ord n) => DependencyInfo n -> HashSet n
|
||||
nodesOnCycles = foldr go mempty . buildSCCs
|
||||
where
|
||||
go :: SCC n -> HashSet n -> HashSet n
|
||||
go x acc = case x of
|
||||
CyclicSCC ns -> foldr HashSet.insert acc ns
|
||||
_ -> acc
|
||||
|
@ -199,6 +199,14 @@ traverseM ::
|
||||
f (m a2)
|
||||
traverseM f = fmap join . traverse f
|
||||
|
||||
composeM :: Monad m => Int -> (a -> m a) -> a -> m a
|
||||
composeM 0 _ a = return a
|
||||
composeM n f a = composeM (n - 1) f a >>= f
|
||||
|
||||
compose :: Int -> (a -> a) -> a -> a
|
||||
compose 0 _ a = a
|
||||
compose n f a = f (compose (n - 1) f a)
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
mapReader :: Member (Reader e1) r => (e1 -> e2) -> Sem (Reader e2 ': r) a -> Sem r a
|
||||
|
@ -316,8 +316,13 @@ tests =
|
||||
$(mkRelFile "test051.juvix")
|
||||
$(mkRelFile "out/test051.out"),
|
||||
posTest
|
||||
"Test052: Mutually recursive types, simple lambda calculus"
|
||||
"Test052: Simple lambda calculus"
|
||||
$(mkRelDir ".")
|
||||
$(mkRelFile "test052.juvix")
|
||||
$(mkRelFile "out/test052.out")
|
||||
$(mkRelFile "out/test052.out"),
|
||||
posTest
|
||||
"Test053: Inlining"
|
||||
$(mkRelDir ".")
|
||||
$(mkRelFile "test053.juvix")
|
||||
$(mkRelFile "out/test053.out")
|
||||
]
|
||||
|
1
tests/Compilation/positive/out/test053.out
Normal file
1
tests/Compilation/positive/out/test053.out
Normal file
@ -0,0 +1 @@
|
||||
16
|
@ -1,4 +1,4 @@
|
||||
--- This module defines a simple lambda calculus and an evaluator.
|
||||
--- Simple lambda claculus
|
||||
module test052;
|
||||
|
||||
open import Stdlib.Prelude;
|
||||
|
23
tests/Compilation/positive/test053.juvix
Normal file
23
tests/Compilation/positive/test053.juvix
Normal file
@ -0,0 +1,23 @@
|
||||
-- Inlining
|
||||
module test052;
|
||||
|
||||
open import Stdlib.Prelude;
|
||||
|
||||
{-# inline: 2 #-}
|
||||
mycompose : {A B C : Type} -> (B -> C) -> (A -> B) -> A -> C;
|
||||
mycompose f g x := f (g x);
|
||||
|
||||
{-# inline: true #-}
|
||||
myconst : {A B : Type} -> A -> B -> A;
|
||||
myconst x _ := x;
|
||||
|
||||
{-# inline: 1 #-}
|
||||
myflip : {A B C : Type} -> (A -> B -> C) -> B -> A -> C;
|
||||
myflip f b a := f a b;
|
||||
|
||||
main : Nat;
|
||||
main :=
|
||||
let f : Nat -> Nat := mycompose λ{x := x + 1} λ{x := x * 2};
|
||||
g : Nat -> Nat -> Nat := myflip myconst;
|
||||
in
|
||||
f 3 + g 7 9;
|
Loading…
Reference in New Issue
Block a user