1
1
mirror of https://github.com/anoma/juvix.git synced 2024-12-25 08:34:10 +03:00

Inline non-recursive functions with only one call site (#3204)

* Closes #3198
This commit is contained in:
Łukasz Czajka 2024-12-04 12:00:58 +01:00 committed by GitHub
parent af9679d557
commit c79f5e3462
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 51 additions and 9 deletions

View File

@ -640,3 +640,32 @@ isDirectlyRecursive md sym = ufold (\x xs -> or (x : xs)) go (lookupIdentifierNo
go = \case go = \case
NIdt Ident {..} -> _identSymbol == sym NIdt Ident {..} -> _identSymbol == sym
_ -> False _ -> False
-- Returns a map from symbols to their number of occurrences in the given node.
getSymbolsMap :: Module -> Node -> HashMap Symbol Int
getSymbolsMap md = gather go mempty
where
go :: HashMap Symbol Int -> Node -> HashMap Symbol Int
go acc = \case
NTyp TypeConstr {..} -> mapInc _typeConstrSymbol acc
NIdt Ident {..} -> mapInc _identSymbol acc
NCase Case {..} -> mapInc _caseInductive acc
NCtr Constr {..}
| Just ci <- lookupConstructorInfo' md _constrTag ->
mapInc (ci ^. constructorInductive) acc
_ -> acc
mapInc :: Symbol -> HashMap Symbol Int -> HashMap Symbol Int
mapInc k = HashMap.insertWith (+) k 1
getTableSymbolsMap :: InfoTable -> HashMap Symbol Int
getTableSymbolsMap tab =
foldr
(HashMap.unionWith (+))
mempty
(map (getSymbolsMap md) (HashMap.elems $ tab ^. identContext))
where
md = emptyModule {_moduleInfoTable = tab}
getModuleSymbolsMap :: Module -> HashMap Symbol Int
getModuleSymbolsMap = getTableSymbolsMap . computeCombinedInfoTable

View File

@ -1,5 +1,6 @@
module Juvix.Compiler.Core.Transformation.Optimize.Inlining where module Juvix.Compiler.Core.Transformation.Optimize.Inlining where
import Data.HashMap.Strict qualified as HashMap
import Data.HashSet qualified as HashSet import Data.HashSet qualified as HashSet
import Juvix.Compiler.Core.Data.BinderList qualified as BL import Juvix.Compiler.Core.Data.BinderList qualified as BL
import Juvix.Compiler.Core.Data.IdentDependencyInfo import Juvix.Compiler.Core.Data.IdentDependencyInfo
@ -16,8 +17,8 @@ isInlineableLambda inlineDepth md bl node = case node of
_ -> _ ->
False False
convertNode :: Int -> HashSet Symbol -> Module -> Node -> Node convertNode :: Int -> HashSet Symbol -> HashMap Symbol Int -> Module -> Node -> Node
convertNode inlineDepth nonRecSyms md = dmapL go convertNode inlineDepth nonRecSyms symOcc md = dmapL go
where where
go :: BinderList Binder -> Node -> Node go :: BinderList Binder -> Node -> Node
go bl node = case node of go bl node = case node of
@ -39,7 +40,9 @@ convertNode inlineDepth nonRecSyms md = dmapL go
_ _
| HashSet.member _identSymbol nonRecSyms | HashSet.member _identSymbol nonRecSyms
&& length args >= argsNum && length args >= argsNum
&& isInlineableLambda inlineDepth md bl def -> && ( HashMap.lookup _identSymbol symOcc == Just 1
|| isInlineableLambda inlineDepth md bl def
) ->
mkApps def args mkApps def args
_ -> _ ->
node node
@ -58,7 +61,10 @@ convertNode inlineDepth nonRecSyms md = dmapL go
Just InlineNever -> node Just InlineNever -> node
_ _
| HashSet.member _identSymbol nonRecSyms | HashSet.member _identSymbol nonRecSyms
&& isImmediate md def -> && argsNum == 0
&& ( HashMap.lookup _identSymbol symOcc == Just 1
|| isImmediate md def
) ->
def def
| otherwise -> | otherwise ->
node node
@ -98,10 +104,10 @@ convertNode inlineDepth nonRecSyms md = dmapL go
where where
(lamsNum, body) = unfoldLambdas' node (lamsNum, body) = unfoldLambdas' node
inlining' :: Int -> HashSet Symbol -> Module -> Module inlining' :: Int -> HashSet Symbol -> HashMap Symbol Int -> Module -> Module
inlining' inliningDepth nonRecSyms md = mapT (const (convertNode inliningDepth nonRecSyms md)) md inlining' inliningDepth nonRecSyms symOcc md = mapT (const (convertNode inliningDepth nonRecSyms symOcc md)) md
inlining :: (Member (Reader CoreOptions) r) => Module -> Sem r Module inlining :: (Member (Reader CoreOptions) r) => Module -> Sem r Module
inlining md = do inlining md = do
d <- asks (^. optInliningDepth) d <- asks (^. optInliningDepth)
return $ inlining' d (nonRecursiveIdents md) md return $ inlining' d (nonRecursiveIdents md) (getModuleSymbolsMap md) md

View File

@ -1,6 +1,7 @@
module Juvix.Compiler.Core.Transformation.Optimize.Phase.Main where module Juvix.Compiler.Core.Transformation.Optimize.Phase.Main where
import Juvix.Compiler.Core.Data.IdentDependencyInfo import Juvix.Compiler.Core.Data.IdentDependencyInfo
import Juvix.Compiler.Core.Extra.Utils (getTableSymbolsMap)
import Juvix.Compiler.Core.Options import Juvix.Compiler.Core.Options
import Juvix.Compiler.Core.Transformation.Base import Juvix.Compiler.Core.Transformation.Base
import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding
@ -39,6 +40,9 @@ optimize' opts@CoreOptions {..} md =
nonRecsReachable :: HashSet Symbol nonRecsReachable :: HashSet Symbol
nonRecsReachable = nonRecursiveReachableIdents' tab nonRecsReachable = nonRecursiveReachableIdents' tab
symOcc :: HashMap Symbol Int
symOcc = getTableSymbolsMap tab
doConstantFolding :: Module -> Module doConstantFolding :: Module -> Module
doConstantFolding md' = constantFolding' opts nonRecs' tab' md' doConstantFolding md' = constantFolding' opts nonRecs' tab' md'
where where
@ -48,7 +52,7 @@ optimize' opts@CoreOptions {..} md =
| otherwise = nonRecsReachable | otherwise = nonRecsReachable
doInlining :: Module -> Module doInlining :: Module -> Module
doInlining md' = inlining' _optInliningDepth nonRecs' md' doInlining md' = inlining' _optInliningDepth nonRecs' symOcc md'
where where
nonRecs' = nonRecs' =
if if

View File

@ -1,6 +1,7 @@
module Juvix.Compiler.Core.Transformation.Optimize.Phase.PreLifting where module Juvix.Compiler.Core.Transformation.Optimize.Phase.PreLifting where
import Juvix.Compiler.Core.Data.IdentDependencyInfo import Juvix.Compiler.Core.Data.IdentDependencyInfo
import Juvix.Compiler.Core.Extra.Utils
import Juvix.Compiler.Core.Options import Juvix.Compiler.Core.Options
import Juvix.Compiler.Core.Transformation.Base import Juvix.Compiler.Core.Transformation.Base
import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding
@ -23,8 +24,9 @@ optimize md = do
2 2
( compose 2 (letFolding' (isInlineableLambda _optInliningDepth)) ( compose 2 (letFolding' (isInlineableLambda _optInliningDepth))
. lambdaFolding . lambdaFolding
. inlining' _optInliningDepth nonRecSyms . inlining' _optInliningDepth nonRecSyms symOcc
) )
. letFolding . letFolding
where where
nonRecSyms = nonRecursiveIdents md nonRecSyms = nonRecursiveIdents md
symOcc = getModuleSymbolsMap md

View File

@ -1,3 +1,4 @@
-- Patterns in definitions
module test086; module test086;
import Stdlib.Prelude open; import Stdlib.Prelude open;