1
1
mirror of https://github.com/anoma/juvix.git synced 2024-12-24 16:12:14 +03:00

Inline non-recursive functions with only one call site (#3204)

* Closes #3198
This commit is contained in:
Łukasz Czajka 2024-12-04 12:00:58 +01:00 committed by GitHub
parent af9679d557
commit c79f5e3462
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 51 additions and 9 deletions

View File

@ -640,3 +640,32 @@ isDirectlyRecursive md sym = ufold (\x xs -> or (x : xs)) go (lookupIdentifierNo
go = \case
NIdt Ident {..} -> _identSymbol == sym
_ -> False
-- Returns a map from symbols to their number of occurrences in the given node.
getSymbolsMap :: Module -> Node -> HashMap Symbol Int
getSymbolsMap md = gather go mempty
where
go :: HashMap Symbol Int -> Node -> HashMap Symbol Int
go acc = \case
NTyp TypeConstr {..} -> mapInc _typeConstrSymbol acc
NIdt Ident {..} -> mapInc _identSymbol acc
NCase Case {..} -> mapInc _caseInductive acc
NCtr Constr {..}
| Just ci <- lookupConstructorInfo' md _constrTag ->
mapInc (ci ^. constructorInductive) acc
_ -> acc
mapInc :: Symbol -> HashMap Symbol Int -> HashMap Symbol Int
mapInc k = HashMap.insertWith (+) k 1
getTableSymbolsMap :: InfoTable -> HashMap Symbol Int
getTableSymbolsMap tab =
foldr
(HashMap.unionWith (+))
mempty
(map (getSymbolsMap md) (HashMap.elems $ tab ^. identContext))
where
md = emptyModule {_moduleInfoTable = tab}
getModuleSymbolsMap :: Module -> HashMap Symbol Int
getModuleSymbolsMap = getTableSymbolsMap . computeCombinedInfoTable

View File

@ -1,5 +1,6 @@
module Juvix.Compiler.Core.Transformation.Optimize.Inlining where
import Data.HashMap.Strict qualified as HashMap
import Data.HashSet qualified as HashSet
import Juvix.Compiler.Core.Data.BinderList qualified as BL
import Juvix.Compiler.Core.Data.IdentDependencyInfo
@ -16,8 +17,8 @@ isInlineableLambda inlineDepth md bl node = case node of
_ ->
False
convertNode :: Int -> HashSet Symbol -> Module -> Node -> Node
convertNode inlineDepth nonRecSyms md = dmapL go
convertNode :: Int -> HashSet Symbol -> HashMap Symbol Int -> Module -> Node -> Node
convertNode inlineDepth nonRecSyms symOcc md = dmapL go
where
go :: BinderList Binder -> Node -> Node
go bl node = case node of
@ -39,7 +40,9 @@ convertNode inlineDepth nonRecSyms md = dmapL go
_
| HashSet.member _identSymbol nonRecSyms
&& length args >= argsNum
&& isInlineableLambda inlineDepth md bl def ->
&& ( HashMap.lookup _identSymbol symOcc == Just 1
|| isInlineableLambda inlineDepth md bl def
) ->
mkApps def args
_ ->
node
@ -58,7 +61,10 @@ convertNode inlineDepth nonRecSyms md = dmapL go
Just InlineNever -> node
_
| HashSet.member _identSymbol nonRecSyms
&& isImmediate md def ->
&& argsNum == 0
&& ( HashMap.lookup _identSymbol symOcc == Just 1
|| isImmediate md def
) ->
def
| otherwise ->
node
@ -98,10 +104,10 @@ convertNode inlineDepth nonRecSyms md = dmapL go
where
(lamsNum, body) = unfoldLambdas' node
inlining' :: Int -> HashSet Symbol -> Module -> Module
inlining' inliningDepth nonRecSyms md = mapT (const (convertNode inliningDepth nonRecSyms md)) md
inlining' :: Int -> HashSet Symbol -> HashMap Symbol Int -> Module -> Module
inlining' inliningDepth nonRecSyms symOcc md = mapT (const (convertNode inliningDepth nonRecSyms symOcc md)) md
inlining :: (Member (Reader CoreOptions) r) => Module -> Sem r Module
inlining md = do
d <- asks (^. optInliningDepth)
return $ inlining' d (nonRecursiveIdents md) md
return $ inlining' d (nonRecursiveIdents md) (getModuleSymbolsMap md) md

View File

@ -1,6 +1,7 @@
module Juvix.Compiler.Core.Transformation.Optimize.Phase.Main where
import Juvix.Compiler.Core.Data.IdentDependencyInfo
import Juvix.Compiler.Core.Extra.Utils (getTableSymbolsMap)
import Juvix.Compiler.Core.Options
import Juvix.Compiler.Core.Transformation.Base
import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding
@ -39,6 +40,9 @@ optimize' opts@CoreOptions {..} md =
nonRecsReachable :: HashSet Symbol
nonRecsReachable = nonRecursiveReachableIdents' tab
symOcc :: HashMap Symbol Int
symOcc = getTableSymbolsMap tab
doConstantFolding :: Module -> Module
doConstantFolding md' = constantFolding' opts nonRecs' tab' md'
where
@ -48,7 +52,7 @@ optimize' opts@CoreOptions {..} md =
| otherwise = nonRecsReachable
doInlining :: Module -> Module
doInlining md' = inlining' _optInliningDepth nonRecs' md'
doInlining md' = inlining' _optInliningDepth nonRecs' symOcc md'
where
nonRecs' =
if

View File

@ -1,6 +1,7 @@
module Juvix.Compiler.Core.Transformation.Optimize.Phase.PreLifting where
import Juvix.Compiler.Core.Data.IdentDependencyInfo
import Juvix.Compiler.Core.Extra.Utils
import Juvix.Compiler.Core.Options
import Juvix.Compiler.Core.Transformation.Base
import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding
@ -23,8 +24,9 @@ optimize md = do
2
( compose 2 (letFolding' (isInlineableLambda _optInliningDepth))
. lambdaFolding
. inlining' _optInliningDepth nonRecSyms
. inlining' _optInliningDepth nonRecSyms symOcc
)
. letFolding
where
nonRecSyms = nonRecursiveIdents md
symOcc = getModuleSymbolsMap md

View File

@ -1,3 +1,4 @@
-- Patterns in definitions
module test086;
import Stdlib.Prelude open;