mirror of
https://github.com/anoma/juvix.git
synced 2024-12-25 08:34:10 +03:00
Inline non-recursive functions with only one call site (#3204)
* Closes #3198
This commit is contained in:
parent
af9679d557
commit
c79f5e3462
@ -640,3 +640,32 @@ isDirectlyRecursive md sym = ufold (\x xs -> or (x : xs)) go (lookupIdentifierNo
|
|||||||
go = \case
|
go = \case
|
||||||
NIdt Ident {..} -> _identSymbol == sym
|
NIdt Ident {..} -> _identSymbol == sym
|
||||||
_ -> False
|
_ -> False
|
||||||
|
|
||||||
|
-- Returns a map from symbols to their number of occurrences in the given node.
|
||||||
|
getSymbolsMap :: Module -> Node -> HashMap Symbol Int
|
||||||
|
getSymbolsMap md = gather go mempty
|
||||||
|
where
|
||||||
|
go :: HashMap Symbol Int -> Node -> HashMap Symbol Int
|
||||||
|
go acc = \case
|
||||||
|
NTyp TypeConstr {..} -> mapInc _typeConstrSymbol acc
|
||||||
|
NIdt Ident {..} -> mapInc _identSymbol acc
|
||||||
|
NCase Case {..} -> mapInc _caseInductive acc
|
||||||
|
NCtr Constr {..}
|
||||||
|
| Just ci <- lookupConstructorInfo' md _constrTag ->
|
||||||
|
mapInc (ci ^. constructorInductive) acc
|
||||||
|
_ -> acc
|
||||||
|
|
||||||
|
mapInc :: Symbol -> HashMap Symbol Int -> HashMap Symbol Int
|
||||||
|
mapInc k = HashMap.insertWith (+) k 1
|
||||||
|
|
||||||
|
getTableSymbolsMap :: InfoTable -> HashMap Symbol Int
|
||||||
|
getTableSymbolsMap tab =
|
||||||
|
foldr
|
||||||
|
(HashMap.unionWith (+))
|
||||||
|
mempty
|
||||||
|
(map (getSymbolsMap md) (HashMap.elems $ tab ^. identContext))
|
||||||
|
where
|
||||||
|
md = emptyModule {_moduleInfoTable = tab}
|
||||||
|
|
||||||
|
getModuleSymbolsMap :: Module -> HashMap Symbol Int
|
||||||
|
getModuleSymbolsMap = getTableSymbolsMap . computeCombinedInfoTable
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
module Juvix.Compiler.Core.Transformation.Optimize.Inlining where
|
module Juvix.Compiler.Core.Transformation.Optimize.Inlining where
|
||||||
|
|
||||||
|
import Data.HashMap.Strict qualified as HashMap
|
||||||
import Data.HashSet qualified as HashSet
|
import Data.HashSet qualified as HashSet
|
||||||
import Juvix.Compiler.Core.Data.BinderList qualified as BL
|
import Juvix.Compiler.Core.Data.BinderList qualified as BL
|
||||||
import Juvix.Compiler.Core.Data.IdentDependencyInfo
|
import Juvix.Compiler.Core.Data.IdentDependencyInfo
|
||||||
@ -16,8 +17,8 @@ isInlineableLambda inlineDepth md bl node = case node of
|
|||||||
_ ->
|
_ ->
|
||||||
False
|
False
|
||||||
|
|
||||||
convertNode :: Int -> HashSet Symbol -> Module -> Node -> Node
|
convertNode :: Int -> HashSet Symbol -> HashMap Symbol Int -> Module -> Node -> Node
|
||||||
convertNode inlineDepth nonRecSyms md = dmapL go
|
convertNode inlineDepth nonRecSyms symOcc md = dmapL go
|
||||||
where
|
where
|
||||||
go :: BinderList Binder -> Node -> Node
|
go :: BinderList Binder -> Node -> Node
|
||||||
go bl node = case node of
|
go bl node = case node of
|
||||||
@ -39,7 +40,9 @@ convertNode inlineDepth nonRecSyms md = dmapL go
|
|||||||
_
|
_
|
||||||
| HashSet.member _identSymbol nonRecSyms
|
| HashSet.member _identSymbol nonRecSyms
|
||||||
&& length args >= argsNum
|
&& length args >= argsNum
|
||||||
&& isInlineableLambda inlineDepth md bl def ->
|
&& ( HashMap.lookup _identSymbol symOcc == Just 1
|
||||||
|
|| isInlineableLambda inlineDepth md bl def
|
||||||
|
) ->
|
||||||
mkApps def args
|
mkApps def args
|
||||||
_ ->
|
_ ->
|
||||||
node
|
node
|
||||||
@ -58,7 +61,10 @@ convertNode inlineDepth nonRecSyms md = dmapL go
|
|||||||
Just InlineNever -> node
|
Just InlineNever -> node
|
||||||
_
|
_
|
||||||
| HashSet.member _identSymbol nonRecSyms
|
| HashSet.member _identSymbol nonRecSyms
|
||||||
&& isImmediate md def ->
|
&& argsNum == 0
|
||||||
|
&& ( HashMap.lookup _identSymbol symOcc == Just 1
|
||||||
|
|| isImmediate md def
|
||||||
|
) ->
|
||||||
def
|
def
|
||||||
| otherwise ->
|
| otherwise ->
|
||||||
node
|
node
|
||||||
@ -98,10 +104,10 @@ convertNode inlineDepth nonRecSyms md = dmapL go
|
|||||||
where
|
where
|
||||||
(lamsNum, body) = unfoldLambdas' node
|
(lamsNum, body) = unfoldLambdas' node
|
||||||
|
|
||||||
inlining' :: Int -> HashSet Symbol -> Module -> Module
|
inlining' :: Int -> HashSet Symbol -> HashMap Symbol Int -> Module -> Module
|
||||||
inlining' inliningDepth nonRecSyms md = mapT (const (convertNode inliningDepth nonRecSyms md)) md
|
inlining' inliningDepth nonRecSyms symOcc md = mapT (const (convertNode inliningDepth nonRecSyms symOcc md)) md
|
||||||
|
|
||||||
inlining :: (Member (Reader CoreOptions) r) => Module -> Sem r Module
|
inlining :: (Member (Reader CoreOptions) r) => Module -> Sem r Module
|
||||||
inlining md = do
|
inlining md = do
|
||||||
d <- asks (^. optInliningDepth)
|
d <- asks (^. optInliningDepth)
|
||||||
return $ inlining' d (nonRecursiveIdents md) md
|
return $ inlining' d (nonRecursiveIdents md) (getModuleSymbolsMap md) md
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
module Juvix.Compiler.Core.Transformation.Optimize.Phase.Main where
|
module Juvix.Compiler.Core.Transformation.Optimize.Phase.Main where
|
||||||
|
|
||||||
import Juvix.Compiler.Core.Data.IdentDependencyInfo
|
import Juvix.Compiler.Core.Data.IdentDependencyInfo
|
||||||
|
import Juvix.Compiler.Core.Extra.Utils (getTableSymbolsMap)
|
||||||
import Juvix.Compiler.Core.Options
|
import Juvix.Compiler.Core.Options
|
||||||
import Juvix.Compiler.Core.Transformation.Base
|
import Juvix.Compiler.Core.Transformation.Base
|
||||||
import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding
|
import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding
|
||||||
@ -39,6 +40,9 @@ optimize' opts@CoreOptions {..} md =
|
|||||||
nonRecsReachable :: HashSet Symbol
|
nonRecsReachable :: HashSet Symbol
|
||||||
nonRecsReachable = nonRecursiveReachableIdents' tab
|
nonRecsReachable = nonRecursiveReachableIdents' tab
|
||||||
|
|
||||||
|
symOcc :: HashMap Symbol Int
|
||||||
|
symOcc = getTableSymbolsMap tab
|
||||||
|
|
||||||
doConstantFolding :: Module -> Module
|
doConstantFolding :: Module -> Module
|
||||||
doConstantFolding md' = constantFolding' opts nonRecs' tab' md'
|
doConstantFolding md' = constantFolding' opts nonRecs' tab' md'
|
||||||
where
|
where
|
||||||
@ -48,7 +52,7 @@ optimize' opts@CoreOptions {..} md =
|
|||||||
| otherwise = nonRecsReachable
|
| otherwise = nonRecsReachable
|
||||||
|
|
||||||
doInlining :: Module -> Module
|
doInlining :: Module -> Module
|
||||||
doInlining md' = inlining' _optInliningDepth nonRecs' md'
|
doInlining md' = inlining' _optInliningDepth nonRecs' symOcc md'
|
||||||
where
|
where
|
||||||
nonRecs' =
|
nonRecs' =
|
||||||
if
|
if
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
module Juvix.Compiler.Core.Transformation.Optimize.Phase.PreLifting where
|
module Juvix.Compiler.Core.Transformation.Optimize.Phase.PreLifting where
|
||||||
|
|
||||||
import Juvix.Compiler.Core.Data.IdentDependencyInfo
|
import Juvix.Compiler.Core.Data.IdentDependencyInfo
|
||||||
|
import Juvix.Compiler.Core.Extra.Utils
|
||||||
import Juvix.Compiler.Core.Options
|
import Juvix.Compiler.Core.Options
|
||||||
import Juvix.Compiler.Core.Transformation.Base
|
import Juvix.Compiler.Core.Transformation.Base
|
||||||
import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding
|
import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding
|
||||||
@ -23,8 +24,9 @@ optimize md = do
|
|||||||
2
|
2
|
||||||
( compose 2 (letFolding' (isInlineableLambda _optInliningDepth))
|
( compose 2 (letFolding' (isInlineableLambda _optInliningDepth))
|
||||||
. lambdaFolding
|
. lambdaFolding
|
||||||
. inlining' _optInliningDepth nonRecSyms
|
. inlining' _optInliningDepth nonRecSyms symOcc
|
||||||
)
|
)
|
||||||
. letFolding
|
. letFolding
|
||||||
where
|
where
|
||||||
nonRecSyms = nonRecursiveIdents md
|
nonRecSyms = nonRecursiveIdents md
|
||||||
|
symOcc = getModuleSymbolsMap md
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
-- Patterns in definitions
|
||||||
module test086;
|
module test086;
|
||||||
|
|
||||||
import Stdlib.Prelude open;
|
import Stdlib.Prelude open;
|
||||||
|
Loading…
Reference in New Issue
Block a user