mirror of
https://github.com/anoma/juvix.git
synced 2025-01-06 23:56:20 +03:00
Inlining (#2036)
* Closes #1989 * Adds optimization phases to the pipline (specified by `opt-phase-eval`, `opt-phase-exec` and `opt-phase-geb` transformations). * Adds the `-O` option to the `compile` command to specify the optimization level. * Functions can be declared for inlining with the `inline` pragma: ``` {-# inline: true #-} const : {A B : Type} -> A -> B -> A; const x _ := x; ``` By default, the function is inlined only if it's fully applied. One can specify that a function (partially) applied to at least `n` explicit arguments should be inlined. ``` {-# inline: 2 #-} compose : {A B C : Type} -> (B -> C) -> (A -> B) -> A -> C; compose f g x := f (g x); ``` Then `compose f g` will be inlined, even though it's not fully applied. But `compose f` won't be inlined. * Non-recursive fully applied functions are automatically inlined if the height of the body term does not exceed the inlining depth limit, which can be specified with the `--inline` option to the `compile` command. * The pragma `inline: false` disables automatic inlining on a per-function basis.
This commit is contained in:
parent
ebeef381e6
commit
8aa54ecc28
@ -23,7 +23,9 @@ getEntry PipelineArg {..} = do
|
|||||||
return $
|
return $
|
||||||
ep
|
ep
|
||||||
{ _entryPointTarget = getTarget (_pipelineArgOptions ^. compileTarget),
|
{ _entryPointTarget = getTarget (_pipelineArgOptions ^. compileTarget),
|
||||||
_entryPointDebug = _pipelineArgOptions ^. compileDebug
|
_entryPointDebug = _pipelineArgOptions ^. compileDebug,
|
||||||
|
_entryPointOptimizationLevel = _pipelineArgOptions ^. compileOptimizationLevel,
|
||||||
|
_entryPointInliningDepth = _pipelineArgOptions ^. compileInliningDepth
|
||||||
}
|
}
|
||||||
where
|
where
|
||||||
getTarget :: CompileTarget -> Backend.Target
|
getTarget :: CompileTarget -> Backend.Target
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
module Commands.Extra.Compile.Options where
|
module Commands.Extra.Compile.Options where
|
||||||
|
|
||||||
import CommonOptions hiding (show)
|
import CommonOptions hiding (show)
|
||||||
|
import Juvix.Compiler.Pipeline.EntryPoint
|
||||||
import Prelude (Show (show))
|
import Prelude (Show (show))
|
||||||
|
|
||||||
data CompileTarget
|
data CompileTarget
|
||||||
@ -9,7 +10,7 @@ data CompileTarget
|
|||||||
| TargetGeb
|
| TargetGeb
|
||||||
| TargetCore
|
| TargetCore
|
||||||
| TargetAsm
|
| TargetAsm
|
||||||
deriving stock (Data, Bounded, Enum)
|
deriving stock (Eq, Data, Bounded, Enum)
|
||||||
|
|
||||||
instance Show CompileTarget where
|
instance Show CompileTarget where
|
||||||
show = \case
|
show = \case
|
||||||
@ -27,7 +28,9 @@ data CompileOptions = CompileOptions
|
|||||||
_compileTerm :: Bool,
|
_compileTerm :: Bool,
|
||||||
_compileOutputFile :: Maybe (AppPath File),
|
_compileOutputFile :: Maybe (AppPath File),
|
||||||
_compileTarget :: CompileTarget,
|
_compileTarget :: CompileTarget,
|
||||||
_compileInputFile :: AppPath File
|
_compileInputFile :: AppPath File,
|
||||||
|
_compileOptimizationLevel :: Int,
|
||||||
|
_compileInliningDepth :: Int
|
||||||
}
|
}
|
||||||
deriving stock (Data)
|
deriving stock (Data)
|
||||||
|
|
||||||
@ -68,10 +71,29 @@ parseCompileOptions supportedTargets parseInputFile = do
|
|||||||
<> help "Produce assembly output only (for targets: wasm32-wasi, native)"
|
<> help "Produce assembly output only (for targets: wasm32-wasi, native)"
|
||||||
)
|
)
|
||||||
_compileTerm <-
|
_compileTerm <-
|
||||||
switch
|
if
|
||||||
( short 'G'
|
| elem TargetGeb supportedTargets ->
|
||||||
<> long "only-term"
|
switch
|
||||||
<> help "Produce term output only (for targets: geb)"
|
( short 'G'
|
||||||
|
<> long "only-term"
|
||||||
|
<> help "Produce term output only (for targets: geb)"
|
||||||
|
)
|
||||||
|
| otherwise ->
|
||||||
|
pure False
|
||||||
|
_compileOptimizationLevel <-
|
||||||
|
option
|
||||||
|
(fromIntegral <$> naturalNumberOpt)
|
||||||
|
( short 'O'
|
||||||
|
<> long "optimize"
|
||||||
|
<> value defaultOptimizationLevel
|
||||||
|
<> help ("Optimization level (default: " <> show defaultOptimizationLevel <> ")")
|
||||||
|
)
|
||||||
|
_compileInliningDepth <-
|
||||||
|
option
|
||||||
|
(fromIntegral <$> naturalNumberOpt)
|
||||||
|
( long "inline"
|
||||||
|
<> value defaultInliningDepth
|
||||||
|
<> help ("Automatic inlining depth limit, logarithmic in the function size (default: " <> show defaultInliningDepth <> ")")
|
||||||
)
|
)
|
||||||
_compileTarget <- optCompileTarget supportedTargets
|
_compileTarget <- optCompileTarget supportedTargets
|
||||||
_compileOutputFile <- optional parseGenericOutputFile
|
_compileOutputFile <- optional parseGenericOutputFile
|
||||||
|
@ -50,7 +50,9 @@ instance CanonicalProjection GlobalOptions Core.CoreOptions where
|
|||||||
project GlobalOptions {..} =
|
project GlobalOptions {..} =
|
||||||
Core.CoreOptions
|
Core.CoreOptions
|
||||||
{ Core._optCheckCoverage = not _globalNoCoverage,
|
{ Core._optCheckCoverage = not _globalNoCoverage,
|
||||||
Core._optUnrollLimit = _globalUnrollLimit
|
Core._optUnrollLimit = _globalUnrollLimit,
|
||||||
|
Core._optOptimizationLevel = defaultOptimizationLevel,
|
||||||
|
Core._optInliningDepth = defaultInliningDepth
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultGlobalOptions :: GlobalOptions
|
defaultGlobalOptions :: GlobalOptions
|
||||||
|
@ -474,7 +474,7 @@ foldS sig code a = snd <$> foldS' sig initialStackInfo code a
|
|||||||
foldS' :: forall r a. (Member (Error AsmError) r) => FoldSig StackInfo r a -> StackInfo -> Code -> a -> Sem r (StackInfo, a)
|
foldS' :: forall r a. (Member (Error AsmError) r) => FoldSig StackInfo r a -> StackInfo -> Code -> a -> Sem r (StackInfo, a)
|
||||||
foldS' sig si code acc = do
|
foldS' sig si code acc = do
|
||||||
(si', fs) <- recurseS' sig' si code
|
(si', fs) <- recurseS' sig' si code
|
||||||
a' <- compose fs acc
|
a' <- compose' fs acc
|
||||||
return (si', a')
|
return (si', a')
|
||||||
where
|
where
|
||||||
sig' :: RecursorSig StackInfo r (a -> Sem r a)
|
sig' :: RecursorSig StackInfo r (a -> Sem r a)
|
||||||
@ -486,21 +486,21 @@ foldS' sig si code acc = do
|
|||||||
return
|
return
|
||||||
( \a -> do
|
( \a -> do
|
||||||
let a' = (sig ^. foldAdjust) a
|
let a' = (sig ^. foldAdjust) a
|
||||||
a1 <- compose br1 a'
|
a1 <- compose' br1 a'
|
||||||
a2 <- compose br2 a'
|
a2 <- compose' br2 a'
|
||||||
(sig ^. foldBranch) s cmd a1 a2 a
|
(sig ^. foldBranch) s cmd a1 a2 a
|
||||||
),
|
),
|
||||||
_recurseCase = \s cmd brs md ->
|
_recurseCase = \s cmd brs md ->
|
||||||
return
|
return
|
||||||
( \a -> do
|
( \a -> do
|
||||||
let a' = (sig ^. foldAdjust) a
|
let a' = (sig ^. foldAdjust) a
|
||||||
as <- mapM (`compose` a') brs
|
as <- mapM (`compose'` a') brs
|
||||||
ad <- case md of
|
ad <- case md of
|
||||||
Just d -> Just <$> compose d a'
|
Just d -> Just <$> compose' d a'
|
||||||
Nothing -> return Nothing
|
Nothing -> return Nothing
|
||||||
(sig ^. foldCase) s cmd as ad a
|
(sig ^. foldCase) s cmd as ad a
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
compose :: [a -> Sem r a] -> a -> Sem r a
|
compose' :: [a -> Sem r a] -> a -> Sem r a
|
||||||
compose lst x = foldr (=<<) (return x) lst
|
compose' lst x = foldr (=<<) (return x) lst
|
||||||
|
@ -25,3 +25,6 @@ createIdentDependencyInfo tab = createDependencyInfo graph startVertices
|
|||||||
|
|
||||||
syms :: [Symbol]
|
syms :: [Symbol]
|
||||||
syms = map (^. identifierSymbol) (HashMap.elems (tab ^. infoIdentifiers))
|
syms = map (^. identifierSymbol) (HashMap.elems (tab ^. infoIdentifiers))
|
||||||
|
|
||||||
|
recursiveIdents :: InfoTable -> HashSet Symbol
|
||||||
|
recursiveIdents = nodesOnCycles . createIdentDependencyInfo
|
||||||
|
@ -21,7 +21,13 @@ data TransformationId
|
|||||||
| CheckGeb
|
| CheckGeb
|
||||||
| CheckExec
|
| CheckExec
|
||||||
| LetFolding
|
| LetFolding
|
||||||
|
| LambdaFolding
|
||||||
|
| Inlining
|
||||||
| FoldTypeSynonyms
|
| FoldTypeSynonyms
|
||||||
|
| OptPhaseEval
|
||||||
|
| OptPhaseExec
|
||||||
|
| OptPhaseGeb
|
||||||
|
| OptPhaseMain
|
||||||
deriving stock (Data, Bounded, Enum, Show)
|
deriving stock (Data, Bounded, Enum, Show)
|
||||||
|
|
||||||
data PipelineId
|
data PipelineId
|
||||||
@ -52,14 +58,14 @@ toTypecheckTransformations :: [TransformationId]
|
|||||||
toTypecheckTransformations = [MatchToCase]
|
toTypecheckTransformations = [MatchToCase]
|
||||||
|
|
||||||
toEvalTransformations :: [TransformationId]
|
toEvalTransformations :: [TransformationId]
|
||||||
toEvalTransformations = [EtaExpandApps, MatchToCase, NatToPrimInt, IntToPrimInt, ConvertBuiltinTypes, LetFolding]
|
toEvalTransformations = [EtaExpandApps, MatchToCase, NatToPrimInt, IntToPrimInt, ConvertBuiltinTypes, OptPhaseEval]
|
||||||
|
|
||||||
toStrippedTransformations :: [TransformationId]
|
toStrippedTransformations :: [TransformationId]
|
||||||
toStrippedTransformations =
|
toStrippedTransformations =
|
||||||
toEvalTransformations ++ [CheckExec, LambdaLetRecLifting, LetFolding, TopEtaExpand, MoveApps, RemoveTypeArgs]
|
toEvalTransformations ++ [CheckExec, LambdaLetRecLifting, OptPhaseExec, TopEtaExpand, MoveApps, RemoveTypeArgs]
|
||||||
|
|
||||||
toGebTransformations :: [TransformationId]
|
toGebTransformations :: [TransformationId]
|
||||||
toGebTransformations = toEvalTransformations ++ [CheckGeb, LetRecLifting, LetFolding, UnrollRecursion, FoldTypeSynonyms, ComputeTypeInfo]
|
toGebTransformations = toEvalTransformations ++ [CheckGeb, LetRecLifting, OptPhaseGeb, UnrollRecursion, FoldTypeSynonyms, ComputeTypeInfo]
|
||||||
|
|
||||||
pipeline :: PipelineId -> [TransformationId]
|
pipeline :: PipelineId -> [TransformationId]
|
||||||
pipeline = \case
|
pipeline = \case
|
||||||
|
@ -79,7 +79,13 @@ transformationText = \case
|
|||||||
CheckGeb -> strCheckGeb
|
CheckGeb -> strCheckGeb
|
||||||
CheckExec -> strCheckExec
|
CheckExec -> strCheckExec
|
||||||
LetFolding -> strLetFolding
|
LetFolding -> strLetFolding
|
||||||
|
LambdaFolding -> strLambdaFolding
|
||||||
|
Inlining -> strInlining
|
||||||
FoldTypeSynonyms -> strFoldTypeSynonyms
|
FoldTypeSynonyms -> strFoldTypeSynonyms
|
||||||
|
OptPhaseEval -> strOptPhaseEval
|
||||||
|
OptPhaseExec -> strOptPhaseExec
|
||||||
|
OptPhaseGeb -> strOptPhaseGeb
|
||||||
|
OptPhaseMain -> strOptPhaseMain
|
||||||
|
|
||||||
parsePipeline :: MonadParsec e Text m => m PipelineId
|
parsePipeline :: MonadParsec e Text m => m PipelineId
|
||||||
parsePipeline = P.choice [symbol (pipelineText t) $> t | t <- allElements]
|
parsePipeline = P.choice [symbol (pipelineText t) $> t | t <- allElements]
|
||||||
@ -153,5 +159,23 @@ strCheckExec = "check-exec"
|
|||||||
strLetFolding :: Text
|
strLetFolding :: Text
|
||||||
strLetFolding = "let-folding"
|
strLetFolding = "let-folding"
|
||||||
|
|
||||||
|
strLambdaFolding :: Text
|
||||||
|
strLambdaFolding = "lambda-folding"
|
||||||
|
|
||||||
|
strInlining :: Text
|
||||||
|
strInlining = "inlining"
|
||||||
|
|
||||||
strFoldTypeSynonyms :: Text
|
strFoldTypeSynonyms :: Text
|
||||||
strFoldTypeSynonyms = "fold-type-synonyms"
|
strFoldTypeSynonyms = "fold-type-synonyms"
|
||||||
|
|
||||||
|
strOptPhaseEval :: Text
|
||||||
|
strOptPhaseEval = "opt-phase-eval"
|
||||||
|
|
||||||
|
strOptPhaseExec :: Text
|
||||||
|
strOptPhaseExec = "opt-phase-exec"
|
||||||
|
|
||||||
|
strOptPhaseGeb :: Text
|
||||||
|
strOptPhaseGeb = "opt-phase-geb"
|
||||||
|
|
||||||
|
strOptPhaseMain :: Text
|
||||||
|
strOptPhaseMain = "opt-phase-main"
|
||||||
|
@ -60,7 +60,10 @@ isImmediate = \case
|
|||||||
NVar {} -> True
|
NVar {} -> True
|
||||||
NIdt {} -> True
|
NIdt {} -> True
|
||||||
NCst {} -> True
|
NCst {} -> True
|
||||||
_ -> False
|
node@(NApp {}) ->
|
||||||
|
let (_, args) = unfoldApps' node
|
||||||
|
in all isType args
|
||||||
|
node -> isType node
|
||||||
|
|
||||||
freeVarsSortedMany :: [Node] -> Set Var
|
freeVarsSortedMany :: [Node] -> Set Var
|
||||||
freeVarsSortedMany n = Set.fromList (n ^.. each . freeVars)
|
freeVarsSortedMany n = Set.fromList (n ^.. each . freeVars)
|
||||||
@ -274,3 +277,7 @@ builtinOpArgTypes = \case
|
|||||||
OpSeq -> [mkDynamic', mkDynamic']
|
OpSeq -> [mkDynamic', mkDynamic']
|
||||||
OpTrace -> [mkDynamic']
|
OpTrace -> [mkDynamic']
|
||||||
OpFail -> [mkTypeString']
|
OpFail -> [mkTypeString']
|
||||||
|
|
||||||
|
checkDepth :: Int -> Node -> Bool
|
||||||
|
checkDepth 0 _ = False
|
||||||
|
checkDepth d node = all (checkDepth (d - 1)) (childrenNodes node)
|
||||||
|
@ -5,7 +5,9 @@ import Juvix.Prelude
|
|||||||
|
|
||||||
data CoreOptions = CoreOptions
|
data CoreOptions = CoreOptions
|
||||||
{ _optCheckCoverage :: Bool,
|
{ _optCheckCoverage :: Bool,
|
||||||
_optUnrollLimit :: Int
|
_optUnrollLimit :: Int,
|
||||||
|
_optOptimizationLevel :: Int,
|
||||||
|
_optInliningDepth :: Int
|
||||||
}
|
}
|
||||||
|
|
||||||
makeLenses ''CoreOptions
|
makeLenses ''CoreOptions
|
||||||
@ -14,12 +16,16 @@ defaultCoreOptions :: CoreOptions
|
|||||||
defaultCoreOptions =
|
defaultCoreOptions =
|
||||||
CoreOptions
|
CoreOptions
|
||||||
{ _optCheckCoverage = True,
|
{ _optCheckCoverage = True,
|
||||||
_optUnrollLimit = defaultUnrollLimit
|
_optUnrollLimit = defaultUnrollLimit,
|
||||||
|
_optOptimizationLevel = defaultOptimizationLevel,
|
||||||
|
_optInliningDepth = defaultInliningDepth
|
||||||
}
|
}
|
||||||
|
|
||||||
fromEntryPoint :: EntryPoint -> CoreOptions
|
fromEntryPoint :: EntryPoint -> CoreOptions
|
||||||
fromEntryPoint EntryPoint {..} =
|
fromEntryPoint EntryPoint {..} =
|
||||||
CoreOptions
|
CoreOptions
|
||||||
{ _optCheckCoverage = not _entryPointNoCoverage,
|
{ _optCheckCoverage = not _entryPointNoCoverage,
|
||||||
_optUnrollLimit = _entryPointUnrollLimit
|
_optUnrollLimit = _entryPointUnrollLimit,
|
||||||
|
_optOptimizationLevel = _entryPointOptimizationLevel,
|
||||||
|
_optInliningDepth = _entryPointInliningDepth
|
||||||
}
|
}
|
||||||
|
@ -26,7 +26,13 @@ import Juvix.Compiler.Core.Transformation.MatchToCase
|
|||||||
import Juvix.Compiler.Core.Transformation.MoveApps
|
import Juvix.Compiler.Core.Transformation.MoveApps
|
||||||
import Juvix.Compiler.Core.Transformation.NaiveMatchToCase qualified as Naive
|
import Juvix.Compiler.Core.Transformation.NaiveMatchToCase qualified as Naive
|
||||||
import Juvix.Compiler.Core.Transformation.NatToPrimInt
|
import Juvix.Compiler.Core.Transformation.NatToPrimInt
|
||||||
|
import Juvix.Compiler.Core.Transformation.Optimize.Inlining
|
||||||
|
import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding
|
||||||
import Juvix.Compiler.Core.Transformation.Optimize.LetFolding
|
import Juvix.Compiler.Core.Transformation.Optimize.LetFolding
|
||||||
|
import Juvix.Compiler.Core.Transformation.Optimize.Phase.Eval qualified as Phase.Eval
|
||||||
|
import Juvix.Compiler.Core.Transformation.Optimize.Phase.Exec qualified as Phase.Exec
|
||||||
|
import Juvix.Compiler.Core.Transformation.Optimize.Phase.Geb qualified as Phase.Geb
|
||||||
|
import Juvix.Compiler.Core.Transformation.Optimize.Phase.Main qualified as Phase.Main
|
||||||
import Juvix.Compiler.Core.Transformation.RemoveTypeArgs
|
import Juvix.Compiler.Core.Transformation.RemoveTypeArgs
|
||||||
import Juvix.Compiler.Core.Transformation.TopEtaExpand
|
import Juvix.Compiler.Core.Transformation.TopEtaExpand
|
||||||
import Juvix.Compiler.Core.Transformation.UnrollRecursion
|
import Juvix.Compiler.Core.Transformation.UnrollRecursion
|
||||||
@ -54,4 +60,10 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts
|
|||||||
CheckGeb -> mapError (JuvixError @CoreError) . checkGeb
|
CheckGeb -> mapError (JuvixError @CoreError) . checkGeb
|
||||||
CheckExec -> mapError (JuvixError @CoreError) . checkExec
|
CheckExec -> mapError (JuvixError @CoreError) . checkExec
|
||||||
LetFolding -> return . letFolding
|
LetFolding -> return . letFolding
|
||||||
|
LambdaFolding -> return . lambdaFolding
|
||||||
|
Inlining -> inlining
|
||||||
FoldTypeSynonyms -> return . foldTypeSynonyms
|
FoldTypeSynonyms -> return . foldTypeSynonyms
|
||||||
|
OptPhaseEval -> Phase.Eval.optimize
|
||||||
|
OptPhaseExec -> Phase.Exec.optimize
|
||||||
|
OptPhaseGeb -> Phase.Geb.optimize
|
||||||
|
OptPhaseMain -> Phase.Main.optimize
|
||||||
|
@ -9,6 +9,7 @@ import Data.HashMap.Strict qualified as HashMap
|
|||||||
import Juvix.Compiler.Core.Data.InfoTable
|
import Juvix.Compiler.Core.Data.InfoTable
|
||||||
import Juvix.Compiler.Core.Data.InfoTableBuilder
|
import Juvix.Compiler.Core.Data.InfoTableBuilder
|
||||||
import Juvix.Compiler.Core.Language
|
import Juvix.Compiler.Core.Language
|
||||||
|
import Juvix.Compiler.Core.Options
|
||||||
|
|
||||||
mapIdentsM :: Monad m => (IdentifierInfo -> m IdentifierInfo) -> InfoTable -> m InfoTable
|
mapIdentsM :: Monad m => (IdentifierInfo -> m IdentifierInfo) -> InfoTable -> m InfoTable
|
||||||
mapIdentsM = overM infoIdentifiers . mapM
|
mapIdentsM = overM infoIdentifiers . mapM
|
||||||
@ -85,3 +86,13 @@ mapAllNodes f tab =
|
|||||||
|
|
||||||
convertAxiom :: AxiomInfo -> AxiomInfo
|
convertAxiom :: AxiomInfo -> AxiomInfo
|
||||||
convertAxiom = over axiomType f
|
convertAxiom = over axiomType f
|
||||||
|
|
||||||
|
withOptimizationLevel :: Member (Reader CoreOptions) r => Int -> (InfoTable -> Sem r InfoTable) -> InfoTable -> Sem r InfoTable
|
||||||
|
withOptimizationLevel n f tab = do
|
||||||
|
l <- asks (^. optOptimizationLevel)
|
||||||
|
if
|
||||||
|
| l >= n -> f tab
|
||||||
|
| otherwise -> return tab
|
||||||
|
|
||||||
|
withOptimizationLevel' :: Member (Reader CoreOptions) r => InfoTable -> Int -> (InfoTable -> Sem r InfoTable) -> Sem r InfoTable
|
||||||
|
withOptimizationLevel' tab n f = withOptimizationLevel n f tab
|
||||||
|
57
src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs
Normal file
57
src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
module Juvix.Compiler.Core.Transformation.Optimize.Inlining where
|
||||||
|
|
||||||
|
import Data.HashSet qualified as HashSet
|
||||||
|
import Juvix.Compiler.Core.Data.IdentDependencyInfo
|
||||||
|
import Juvix.Compiler.Core.Extra
|
||||||
|
import Juvix.Compiler.Core.Options
|
||||||
|
import Juvix.Compiler.Core.Transformation.Base
|
||||||
|
|
||||||
|
isInlineableLambda :: Int -> Node -> Bool
|
||||||
|
isInlineableLambda inlineDepth node = case node of
|
||||||
|
NLam {} ->
|
||||||
|
checkDepth inlineDepth (snd (unfoldLambdas node))
|
||||||
|
_ ->
|
||||||
|
False
|
||||||
|
|
||||||
|
convertNode :: Int -> HashSet Symbol -> InfoTable -> Node -> Node
|
||||||
|
convertNode inlineDepth recSyms tab = dmap go
|
||||||
|
where
|
||||||
|
go :: Node -> Node
|
||||||
|
go node = case node of
|
||||||
|
NApp {} ->
|
||||||
|
let (h, args) = unfoldApps node
|
||||||
|
in case h of
|
||||||
|
NIdt Ident {..} ->
|
||||||
|
case pi of
|
||||||
|
Just InlineFullyApplied
|
||||||
|
| length args >= argsNum ->
|
||||||
|
mkApps def args
|
||||||
|
Just (InlinePartiallyApplied k)
|
||||||
|
| length args >= k ->
|
||||||
|
mkApps def args
|
||||||
|
Just InlineNever ->
|
||||||
|
node
|
||||||
|
_
|
||||||
|
| not (HashSet.member _identSymbol recSyms)
|
||||||
|
&& isInlineableLambda inlineDepth def
|
||||||
|
&& length args >= argsNum ->
|
||||||
|
mkApps def args
|
||||||
|
_ ->
|
||||||
|
node
|
||||||
|
where
|
||||||
|
ii = lookupIdentifierInfo tab _identSymbol
|
||||||
|
pi = ii ^. identifierPragmas . pragmasInline
|
||||||
|
argsNum = ii ^. identifierArgsNum
|
||||||
|
def = lookupIdentifierNode tab _identSymbol
|
||||||
|
_ ->
|
||||||
|
node
|
||||||
|
_ ->
|
||||||
|
node
|
||||||
|
|
||||||
|
inlining' :: Int -> HashSet Symbol -> InfoTable -> InfoTable
|
||||||
|
inlining' inliningDepth recSyms tab = mapT (const (convertNode inliningDepth recSyms tab)) tab
|
||||||
|
|
||||||
|
inlining :: Member (Reader CoreOptions) r => InfoTable -> Sem r InfoTable
|
||||||
|
inlining tab = do
|
||||||
|
d <- asks (^. optInliningDepth)
|
||||||
|
return $ inlining' d (recursiveIdents tab) tab
|
@ -0,0 +1,45 @@
|
|||||||
|
-- An optimizing transformation that converts beta-redexes into let-expressions.
|
||||||
|
--
|
||||||
|
-- For example, transforms
|
||||||
|
-- ```
|
||||||
|
-- (\x \y x + y * x) a b
|
||||||
|
-- ```
|
||||||
|
-- to
|
||||||
|
-- ```
|
||||||
|
-- let x := a in let y := b in x + y * x
|
||||||
|
-- ```
|
||||||
|
module Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding where
|
||||||
|
|
||||||
|
import Juvix.Compiler.Core.Extra
|
||||||
|
import Juvix.Compiler.Core.Transformation.Base
|
||||||
|
|
||||||
|
convertNode :: Node -> Node
|
||||||
|
convertNode = rmap go
|
||||||
|
where
|
||||||
|
go :: ([BinderChange] -> Node -> Node) -> Node -> Node
|
||||||
|
go recur node = case node of
|
||||||
|
NApp {} ->
|
||||||
|
let (h, args) = unfoldApps' node
|
||||||
|
(lams, body) = unfoldLambdas h
|
||||||
|
in goLams [] lams args body
|
||||||
|
where
|
||||||
|
goLams :: [BinderChange] -> [LambdaLhs] -> [Node] -> Node -> Node
|
||||||
|
goLams bcs lams args body =
|
||||||
|
case (lams, args) of
|
||||||
|
([], _) ->
|
||||||
|
mkApps'
|
||||||
|
(go (recur . (revAppend bcs)) body)
|
||||||
|
(map (go (recur . (BCAdd (length bcs) :))) args)
|
||||||
|
(lam : lams', arg : args') ->
|
||||||
|
mkLet mempty bd' (go (recur . (BCAdd (length bcs) :)) arg) $
|
||||||
|
goLams (BCKeep bd : bcs) lams' args' body
|
||||||
|
where
|
||||||
|
bd = lam ^. lambdaLhsBinder
|
||||||
|
bd' = over binderType (go (recur . (revAppend bcs))) bd
|
||||||
|
(_, []) ->
|
||||||
|
go (recur . (revAppend bcs)) (reLambdas lams body)
|
||||||
|
_ ->
|
||||||
|
recur [] node
|
||||||
|
|
||||||
|
lambdaFolding :: InfoTable -> InfoTable
|
||||||
|
lambdaFolding = mapAllNodes convertNode
|
@ -15,18 +15,22 @@ module Juvix.Compiler.Core.Transformation.Optimize.LetFolding where
|
|||||||
import Juvix.Compiler.Core.Extra
|
import Juvix.Compiler.Core.Extra
|
||||||
import Juvix.Compiler.Core.Transformation.Base
|
import Juvix.Compiler.Core.Transformation.Base
|
||||||
|
|
||||||
convertNode :: Node -> Node
|
convertNode :: (Node -> Bool) -> Node -> Node
|
||||||
convertNode = rmap go
|
convertNode isFoldable = rmap go
|
||||||
where
|
where
|
||||||
go :: ([BinderChange] -> Node -> Node) -> Node -> Node
|
go :: ([BinderChange] -> Node -> Node) -> Node -> Node
|
||||||
go recur = \case
|
go recur = \case
|
||||||
NLet Let {..}
|
NLet Let {..}
|
||||||
| isImmediate (_letItem ^. letItemValue) ->
|
| isImmediate (_letItem ^. letItemValue)
|
||||||
|
|| isFoldable (_letItem ^. letItemValue) ->
|
||||||
go (recur . (mkBCRemove (_letItem ^. letItemBinder) val' :)) _letBody
|
go (recur . (mkBCRemove (_letItem ^. letItemBinder) val' :)) _letBody
|
||||||
where
|
where
|
||||||
val' = go recur (_letItem ^. letItemValue)
|
val' = go recur (_letItem ^. letItemValue)
|
||||||
node ->
|
node ->
|
||||||
recur [] node
|
recur [] node
|
||||||
|
|
||||||
|
letFolding' :: (Node -> Bool) -> InfoTable -> InfoTable
|
||||||
|
letFolding' isFoldable = mapAllNodes (convertNode isFoldable)
|
||||||
|
|
||||||
letFolding :: InfoTable -> InfoTable
|
letFolding :: InfoTable -> InfoTable
|
||||||
letFolding = mapAllNodes convertNode
|
letFolding = letFolding' (const False)
|
||||||
|
@ -0,0 +1,11 @@
|
|||||||
|
module Juvix.Compiler.Core.Transformation.Optimize.Phase.Eval where
|
||||||
|
|
||||||
|
import Juvix.Compiler.Core.Options
|
||||||
|
import Juvix.Compiler.Core.Transformation.Base
|
||||||
|
import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding
|
||||||
|
import Juvix.Compiler.Core.Transformation.Optimize.LetFolding
|
||||||
|
|
||||||
|
optimize :: Member (Reader CoreOptions) r => InfoTable -> Sem r InfoTable
|
||||||
|
optimize =
|
||||||
|
withOptimizationLevel 1 $
|
||||||
|
return . letFolding . lambdaFolding
|
@ -0,0 +1,16 @@
|
|||||||
|
module Juvix.Compiler.Core.Transformation.Optimize.Phase.Exec where
|
||||||
|
|
||||||
|
import Juvix.Compiler.Core.Options
|
||||||
|
import Juvix.Compiler.Core.Transformation.Base
|
||||||
|
import Juvix.Compiler.Core.Transformation.LambdaLetRecLifting
|
||||||
|
import Juvix.Compiler.Core.Transformation.Optimize.LetFolding
|
||||||
|
import Juvix.Compiler.Core.Transformation.Optimize.Phase.Main qualified as Main
|
||||||
|
|
||||||
|
optimize :: Member (Reader CoreOptions) r => InfoTable -> Sem r InfoTable
|
||||||
|
optimize tab = do
|
||||||
|
opts <- ask
|
||||||
|
withOptimizationLevel' tab 1 $
|
||||||
|
return
|
||||||
|
. letFolding
|
||||||
|
. lambdaLetRecLifting
|
||||||
|
. Main.optimize' opts
|
@ -0,0 +1,8 @@
|
|||||||
|
module Juvix.Compiler.Core.Transformation.Optimize.Phase.Geb where
|
||||||
|
|
||||||
|
import Juvix.Compiler.Core.Options
|
||||||
|
import Juvix.Compiler.Core.Transformation.Base
|
||||||
|
import Juvix.Compiler.Core.Transformation.Optimize.Phase.Main qualified as Main
|
||||||
|
|
||||||
|
optimize :: Member (Reader CoreOptions) r => InfoTable -> Sem r InfoTable
|
||||||
|
optimize = withOptimizationLevel 1 Main.optimize
|
@ -0,0 +1,24 @@
|
|||||||
|
module Juvix.Compiler.Core.Transformation.Optimize.Phase.Main where
|
||||||
|
|
||||||
|
import Juvix.Compiler.Core.Data.IdentDependencyInfo
|
||||||
|
import Juvix.Compiler.Core.Options
|
||||||
|
import Juvix.Compiler.Core.Transformation.Base
|
||||||
|
import Juvix.Compiler.Core.Transformation.Optimize.Inlining
|
||||||
|
import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding
|
||||||
|
import Juvix.Compiler.Core.Transformation.Optimize.LetFolding
|
||||||
|
|
||||||
|
optimize' :: CoreOptions -> InfoTable -> InfoTable
|
||||||
|
optimize' CoreOptions {..} tab =
|
||||||
|
compose
|
||||||
|
(4 * _optOptimizationLevel)
|
||||||
|
( compose 2 (letFolding' (isInlineableLambda _optInliningDepth))
|
||||||
|
. lambdaFolding
|
||||||
|
. inlining' _optInliningDepth (recursiveIdents tab)
|
||||||
|
)
|
||||||
|
. letFolding
|
||||||
|
$ tab
|
||||||
|
|
||||||
|
optimize :: Member (Reader CoreOptions) r => InfoTable -> Sem r InfoTable
|
||||||
|
optimize tab = do
|
||||||
|
opts <- ask
|
||||||
|
return $ optimize' opts tab
|
@ -256,7 +256,11 @@ preFunctionDef f = do
|
|||||||
_identifierArgsNum = 0,
|
_identifierArgsNum = 0,
|
||||||
_identifierIsExported = False,
|
_identifierIsExported = False,
|
||||||
_identifierBuiltin = f ^. Internal.funDefBuiltin,
|
_identifierBuiltin = f ^. Internal.funDefBuiltin,
|
||||||
_identifierPragmas = f ^. Internal.funDefPragmas
|
_identifierPragmas =
|
||||||
|
over
|
||||||
|
pragmasInline
|
||||||
|
(fmap (adjustPragmaInline (implicitParametersNum (f ^. Internal.funDefType))))
|
||||||
|
(f ^. Internal.funDefPragmas)
|
||||||
}
|
}
|
||||||
case f ^. Internal.funDefBuiltin of
|
case f ^. Internal.funDefBuiltin of
|
||||||
Just b
|
Just b
|
||||||
@ -279,6 +283,18 @@ preFunctionDef f = do
|
|||||||
">=" -> Str.natGe
|
">=" -> Str.natGe
|
||||||
_ -> name
|
_ -> name
|
||||||
|
|
||||||
|
implicitParametersNum :: Internal.Expression -> Int
|
||||||
|
implicitParametersNum = \case
|
||||||
|
Internal.ExpressionFunction Internal.Function {..}
|
||||||
|
| _functionLeft ^. Internal.paramImplicit == Implicit ->
|
||||||
|
implicitParametersNum _functionRight + 1
|
||||||
|
_ -> 0
|
||||||
|
|
||||||
|
adjustPragmaInline :: Int -> PragmaInline -> PragmaInline
|
||||||
|
adjustPragmaInline n = \case
|
||||||
|
InlinePartiallyApplied k -> InlinePartiallyApplied (k + n)
|
||||||
|
x -> x
|
||||||
|
|
||||||
goFunctionDef ::
|
goFunctionDef ::
|
||||||
forall r.
|
forall r.
|
||||||
(Members '[InfoTableBuilder, Reader Internal.InfoTable, Reader InternalTyped.TypesTable, State InternalTyped.FunctionsTable] r) =>
|
(Members '[InfoTableBuilder, Reader Internal.InfoTable, Reader InternalTyped.TypesTable, State InternalTyped.FunctionsTable] r) =>
|
||||||
|
@ -27,6 +27,8 @@ data EntryPoint = EntryPoint
|
|||||||
_entryPointTarget :: Target,
|
_entryPointTarget :: Target,
|
||||||
_entryPointDebug :: Bool,
|
_entryPointDebug :: Bool,
|
||||||
_entryPointUnrollLimit :: Int,
|
_entryPointUnrollLimit :: Int,
|
||||||
|
_entryPointOptimizationLevel :: Int,
|
||||||
|
_entryPointInliningDepth :: Int,
|
||||||
_entryPointGenericOptions :: GenericOptions,
|
_entryPointGenericOptions :: GenericOptions,
|
||||||
_entryPointModulePaths :: [Path Abs File]
|
_entryPointModulePaths :: [Path Abs File]
|
||||||
}
|
}
|
||||||
@ -69,11 +71,19 @@ defaultEntryPointNoFile roots =
|
|||||||
_entryPointTarget = TargetCore,
|
_entryPointTarget = TargetCore,
|
||||||
_entryPointDebug = False,
|
_entryPointDebug = False,
|
||||||
_entryPointUnrollLimit = defaultUnrollLimit,
|
_entryPointUnrollLimit = defaultUnrollLimit,
|
||||||
|
_entryPointOptimizationLevel = defaultOptimizationLevel,
|
||||||
|
_entryPointInliningDepth = defaultInliningDepth,
|
||||||
_entryPointModulePaths = []
|
_entryPointModulePaths = []
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultUnrollLimit :: Int
|
defaultUnrollLimit :: Int
|
||||||
defaultUnrollLimit = 140
|
defaultUnrollLimit = 140
|
||||||
|
|
||||||
|
defaultOptimizationLevel :: Int
|
||||||
|
defaultOptimizationLevel = 1
|
||||||
|
|
||||||
|
defaultInliningDepth :: Int
|
||||||
|
defaultInliningDepth = 2
|
||||||
|
|
||||||
mainModulePath :: Traversal' EntryPoint (Path Abs File)
|
mainModulePath :: Traversal' EntryPoint (Path Abs File)
|
||||||
mainModulePath = entryPointModulePaths . _head
|
mainModulePath = entryPointModulePaths . _head
|
||||||
|
@ -65,3 +65,11 @@ buildSCCs = Graph.stronglyConnComp . (^. depInfoEdgeList)
|
|||||||
|
|
||||||
isCyclic :: Ord n => DependencyInfo n -> Bool
|
isCyclic :: Ord n => DependencyInfo n -> Bool
|
||||||
isCyclic = any (\case CyclicSCC _ -> True; _ -> False) . buildSCCs
|
isCyclic = any (\case CyclicSCC _ -> True; _ -> False) . buildSCCs
|
||||||
|
|
||||||
|
nodesOnCycles :: forall n. (Hashable n, Ord n) => DependencyInfo n -> HashSet n
|
||||||
|
nodesOnCycles = foldr go mempty . buildSCCs
|
||||||
|
where
|
||||||
|
go :: SCC n -> HashSet n -> HashSet n
|
||||||
|
go x acc = case x of
|
||||||
|
CyclicSCC ns -> foldr HashSet.insert acc ns
|
||||||
|
_ -> acc
|
||||||
|
@ -199,6 +199,14 @@ traverseM ::
|
|||||||
f (m a2)
|
f (m a2)
|
||||||
traverseM f = fmap join . traverse f
|
traverseM f = fmap join . traverse f
|
||||||
|
|
||||||
|
composeM :: Monad m => Int -> (a -> m a) -> a -> m a
|
||||||
|
composeM 0 _ a = return a
|
||||||
|
composeM n f a = composeM (n - 1) f a >>= f
|
||||||
|
|
||||||
|
compose :: Int -> (a -> a) -> a -> a
|
||||||
|
compose 0 _ a = a
|
||||||
|
compose n f a = f (compose (n - 1) f a)
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
mapReader :: Member (Reader e1) r => (e1 -> e2) -> Sem (Reader e2 ': r) a -> Sem r a
|
mapReader :: Member (Reader e1) r => (e1 -> e2) -> Sem (Reader e2 ': r) a -> Sem r a
|
||||||
|
@ -316,8 +316,13 @@ tests =
|
|||||||
$(mkRelFile "test051.juvix")
|
$(mkRelFile "test051.juvix")
|
||||||
$(mkRelFile "out/test051.out"),
|
$(mkRelFile "out/test051.out"),
|
||||||
posTest
|
posTest
|
||||||
"Test052: Mutually recursive types, simple lambda calculus"
|
"Test052: Simple lambda calculus"
|
||||||
$(mkRelDir ".")
|
$(mkRelDir ".")
|
||||||
$(mkRelFile "test052.juvix")
|
$(mkRelFile "test052.juvix")
|
||||||
$(mkRelFile "out/test052.out")
|
$(mkRelFile "out/test052.out"),
|
||||||
|
posTest
|
||||||
|
"Test053: Inlining"
|
||||||
|
$(mkRelDir ".")
|
||||||
|
$(mkRelFile "test053.juvix")
|
||||||
|
$(mkRelFile "out/test053.out")
|
||||||
]
|
]
|
||||||
|
1
tests/Compilation/positive/out/test053.out
Normal file
1
tests/Compilation/positive/out/test053.out
Normal file
@ -0,0 +1 @@
|
|||||||
|
16
|
@ -1,4 +1,4 @@
|
|||||||
--- This module defines a simple lambda calculus and an evaluator.
|
--- Simple lambda claculus
|
||||||
module test052;
|
module test052;
|
||||||
|
|
||||||
open import Stdlib.Prelude;
|
open import Stdlib.Prelude;
|
||||||
|
23
tests/Compilation/positive/test053.juvix
Normal file
23
tests/Compilation/positive/test053.juvix
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
-- Inlining
|
||||||
|
module test052;
|
||||||
|
|
||||||
|
open import Stdlib.Prelude;
|
||||||
|
|
||||||
|
{-# inline: 2 #-}
|
||||||
|
mycompose : {A B C : Type} -> (B -> C) -> (A -> B) -> A -> C;
|
||||||
|
mycompose f g x := f (g x);
|
||||||
|
|
||||||
|
{-# inline: true #-}
|
||||||
|
myconst : {A B : Type} -> A -> B -> A;
|
||||||
|
myconst x _ := x;
|
||||||
|
|
||||||
|
{-# inline: 1 #-}
|
||||||
|
myflip : {A B C : Type} -> (A -> B -> C) -> B -> A -> C;
|
||||||
|
myflip f b a := f a b;
|
||||||
|
|
||||||
|
main : Nat;
|
||||||
|
main :=
|
||||||
|
let f : Nat -> Nat := mycompose λ{x := x + 1} λ{x := x * 2};
|
||||||
|
g : Nat -> Nat -> Nat := myflip myconst;
|
||||||
|
in
|
||||||
|
f 3 + g 7 9;
|
Loading…
Reference in New Issue
Block a user