mirror of
https://github.com/anoma/juvix.git
synced 2024-12-24 16:12:14 +03:00
Inline non-recursive functions with only one call site (#3204)
* Closes #3198
This commit is contained in:
parent
af9679d557
commit
c79f5e3462
@ -640,3 +640,32 @@ isDirectlyRecursive md sym = ufold (\x xs -> or (x : xs)) go (lookupIdentifierNo
|
||||
go = \case
|
||||
NIdt Ident {..} -> _identSymbol == sym
|
||||
_ -> False
|
||||
|
||||
-- Returns a map from symbols to their number of occurrences in the given node.
|
||||
getSymbolsMap :: Module -> Node -> HashMap Symbol Int
|
||||
getSymbolsMap md = gather go mempty
|
||||
where
|
||||
go :: HashMap Symbol Int -> Node -> HashMap Symbol Int
|
||||
go acc = \case
|
||||
NTyp TypeConstr {..} -> mapInc _typeConstrSymbol acc
|
||||
NIdt Ident {..} -> mapInc _identSymbol acc
|
||||
NCase Case {..} -> mapInc _caseInductive acc
|
||||
NCtr Constr {..}
|
||||
| Just ci <- lookupConstructorInfo' md _constrTag ->
|
||||
mapInc (ci ^. constructorInductive) acc
|
||||
_ -> acc
|
||||
|
||||
mapInc :: Symbol -> HashMap Symbol Int -> HashMap Symbol Int
|
||||
mapInc k = HashMap.insertWith (+) k 1
|
||||
|
||||
getTableSymbolsMap :: InfoTable -> HashMap Symbol Int
|
||||
getTableSymbolsMap tab =
|
||||
foldr
|
||||
(HashMap.unionWith (+))
|
||||
mempty
|
||||
(map (getSymbolsMap md) (HashMap.elems $ tab ^. identContext))
|
||||
where
|
||||
md = emptyModule {_moduleInfoTable = tab}
|
||||
|
||||
getModuleSymbolsMap :: Module -> HashMap Symbol Int
|
||||
getModuleSymbolsMap = getTableSymbolsMap . computeCombinedInfoTable
|
||||
|
@ -1,5 +1,6 @@
|
||||
module Juvix.Compiler.Core.Transformation.Optimize.Inlining where
|
||||
|
||||
import Data.HashMap.Strict qualified as HashMap
|
||||
import Data.HashSet qualified as HashSet
|
||||
import Juvix.Compiler.Core.Data.BinderList qualified as BL
|
||||
import Juvix.Compiler.Core.Data.IdentDependencyInfo
|
||||
@ -16,8 +17,8 @@ isInlineableLambda inlineDepth md bl node = case node of
|
||||
_ ->
|
||||
False
|
||||
|
||||
convertNode :: Int -> HashSet Symbol -> Module -> Node -> Node
|
||||
convertNode inlineDepth nonRecSyms md = dmapL go
|
||||
convertNode :: Int -> HashSet Symbol -> HashMap Symbol Int -> Module -> Node -> Node
|
||||
convertNode inlineDepth nonRecSyms symOcc md = dmapL go
|
||||
where
|
||||
go :: BinderList Binder -> Node -> Node
|
||||
go bl node = case node of
|
||||
@ -39,7 +40,9 @@ convertNode inlineDepth nonRecSyms md = dmapL go
|
||||
_
|
||||
| HashSet.member _identSymbol nonRecSyms
|
||||
&& length args >= argsNum
|
||||
&& isInlineableLambda inlineDepth md bl def ->
|
||||
&& ( HashMap.lookup _identSymbol symOcc == Just 1
|
||||
|| isInlineableLambda inlineDepth md bl def
|
||||
) ->
|
||||
mkApps def args
|
||||
_ ->
|
||||
node
|
||||
@ -58,7 +61,10 @@ convertNode inlineDepth nonRecSyms md = dmapL go
|
||||
Just InlineNever -> node
|
||||
_
|
||||
| HashSet.member _identSymbol nonRecSyms
|
||||
&& isImmediate md def ->
|
||||
&& argsNum == 0
|
||||
&& ( HashMap.lookup _identSymbol symOcc == Just 1
|
||||
|| isImmediate md def
|
||||
) ->
|
||||
def
|
||||
| otherwise ->
|
||||
node
|
||||
@ -98,10 +104,10 @@ convertNode inlineDepth nonRecSyms md = dmapL go
|
||||
where
|
||||
(lamsNum, body) = unfoldLambdas' node
|
||||
|
||||
inlining' :: Int -> HashSet Symbol -> Module -> Module
|
||||
inlining' inliningDepth nonRecSyms md = mapT (const (convertNode inliningDepth nonRecSyms md)) md
|
||||
inlining' :: Int -> HashSet Symbol -> HashMap Symbol Int -> Module -> Module
|
||||
inlining' inliningDepth nonRecSyms symOcc md = mapT (const (convertNode inliningDepth nonRecSyms symOcc md)) md
|
||||
|
||||
inlining :: (Member (Reader CoreOptions) r) => Module -> Sem r Module
|
||||
inlining md = do
|
||||
d <- asks (^. optInliningDepth)
|
||||
return $ inlining' d (nonRecursiveIdents md) md
|
||||
return $ inlining' d (nonRecursiveIdents md) (getModuleSymbolsMap md) md
|
||||
|
@ -1,6 +1,7 @@
|
||||
module Juvix.Compiler.Core.Transformation.Optimize.Phase.Main where
|
||||
|
||||
import Juvix.Compiler.Core.Data.IdentDependencyInfo
|
||||
import Juvix.Compiler.Core.Extra.Utils (getTableSymbolsMap)
|
||||
import Juvix.Compiler.Core.Options
|
||||
import Juvix.Compiler.Core.Transformation.Base
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding
|
||||
@ -39,6 +40,9 @@ optimize' opts@CoreOptions {..} md =
|
||||
nonRecsReachable :: HashSet Symbol
|
||||
nonRecsReachable = nonRecursiveReachableIdents' tab
|
||||
|
||||
symOcc :: HashMap Symbol Int
|
||||
symOcc = getTableSymbolsMap tab
|
||||
|
||||
doConstantFolding :: Module -> Module
|
||||
doConstantFolding md' = constantFolding' opts nonRecs' tab' md'
|
||||
where
|
||||
@ -48,7 +52,7 @@ optimize' opts@CoreOptions {..} md =
|
||||
| otherwise = nonRecsReachable
|
||||
|
||||
doInlining :: Module -> Module
|
||||
doInlining md' = inlining' _optInliningDepth nonRecs' md'
|
||||
doInlining md' = inlining' _optInliningDepth nonRecs' symOcc md'
|
||||
where
|
||||
nonRecs' =
|
||||
if
|
||||
|
@ -1,6 +1,7 @@
|
||||
module Juvix.Compiler.Core.Transformation.Optimize.Phase.PreLifting where
|
||||
|
||||
import Juvix.Compiler.Core.Data.IdentDependencyInfo
|
||||
import Juvix.Compiler.Core.Extra.Utils
|
||||
import Juvix.Compiler.Core.Options
|
||||
import Juvix.Compiler.Core.Transformation.Base
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding
|
||||
@ -23,8 +24,9 @@ optimize md = do
|
||||
2
|
||||
( compose 2 (letFolding' (isInlineableLambda _optInliningDepth))
|
||||
. lambdaFolding
|
||||
. inlining' _optInliningDepth nonRecSyms
|
||||
. inlining' _optInliningDepth nonRecSyms symOcc
|
||||
)
|
||||
. letFolding
|
||||
where
|
||||
nonRecSyms = nonRecursiveIdents md
|
||||
symOcc = getModuleSymbolsMap md
|
||||
|
@ -1,3 +1,4 @@
|
||||
-- Patterns in definitions
|
||||
module test086;
|
||||
|
||||
import Stdlib.Prelude open;
|
||||
|
Loading…
Reference in New Issue
Block a user