From 39ef069bfcf670279a911695c70f79fc7f61a569 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Czajka?= <62751+lukaszcz@users.noreply.github.com> Date: Thu, 29 Jun 2023 13:02:10 +0200 Subject: [PATCH] Fold lets when the bound variable occurs at most once (#2231) For example, convert ``` let x := f a b c in g x ``` to ``` g (f a b c) ``` --- src/Juvix/Compiler/Core/Info/FreeVarsInfo.hs | 49 +++++++++++++++++++ .../Transformation/Optimize/LetFolding.hs | 18 +++++-- tests/Core/positive/test034.jvc | 4 +- 3 files changed, 64 insertions(+), 7 deletions(-) create mode 100644 src/Juvix/Compiler/Core/Info/FreeVarsInfo.hs diff --git a/src/Juvix/Compiler/Core/Info/FreeVarsInfo.hs b/src/Juvix/Compiler/Core/Info/FreeVarsInfo.hs new file mode 100644 index 000000000..3a5de863f --- /dev/null +++ b/src/Juvix/Compiler/Core/Info/FreeVarsInfo.hs @@ -0,0 +1,49 @@ +module Juvix.Compiler.Core.Info.FreeVarsInfo where + +import Data.Map qualified as Map +import Juvix.Compiler.Core.Extra +import Juvix.Compiler.Core.Info qualified as Info +import Juvix.Compiler.Core.Language + +newtype FreeVarsInfo = FreeVarsInfo + { -- map free variables to the number of their occurrences + _infoFreeVars :: Map Index Int + } + +instance IsInfo FreeVarsInfo + +kFreeVarsInfo :: Key FreeVarsInfo +kFreeVarsInfo = Proxy + +makeLenses ''FreeVarsInfo + +computeFreeVarsInfo :: Node -> Node +computeFreeVarsInfo = umap go + where + go :: Node -> Node + go node = case node of + NVar Var {..} -> + mkVar (Info.insert fvi _varInfo) _varIndex + where + fvi = FreeVarsInfo (Map.singleton _varIndex 1) + _ -> + modifyInfo (Info.insert fvi) node + where + fvi = + FreeVarsInfo $ + foldr + ( \NodeChild {..} acc -> + Map.unionWith (+) acc $ + Map.mapKeysMonotonic (\idx -> idx - _childBindersNum) $ + Map.filterWithKey + (\idx _ -> idx >= _childBindersNum) + (getFreeVarsInfo _childNode ^. infoFreeVars) + ) + mempty + (children node) + +getFreeVarsInfo :: Node -> FreeVarsInfo +getFreeVarsInfo = fromJust . Info.lookup kFreeVarsInfo . getInfo + +freeVarOccurrences :: Index -> Node -> Int +freeVarOccurrences idx n = fromMaybe 0 (Map.lookup idx (getFreeVarsInfo n ^. infoFreeVars)) diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/LetFolding.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/LetFolding.hs index b10738298..1a35ae56f 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/LetFolding.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/LetFolding.hs @@ -1,18 +1,19 @@ -- An optimizing transformation that folds lets whose values are immediate, -- i.e., they don't require evaluation or memory allocation (variables or --- constants). +-- constants), or when the bound variable occurs at most once in the body. -- -- For example, transforms -- ``` --- let x := y in let z := x + x in x + z +-- let x := y in let z := x + x in let u := z + y in x * x + z * z + u -- ``` -- to -- ``` --- let z := y + y in y + z +-- let z := y + y in y * y + z * z + z + y -- ``` -module Juvix.Compiler.Core.Transformation.Optimize.LetFolding where +module Juvix.Compiler.Core.Transformation.Optimize.LetFolding (letFolding, letFolding') where import Juvix.Compiler.Core.Extra +import Juvix.Compiler.Core.Info.FreeVarsInfo as Info import Juvix.Compiler.Core.Transformation.Base convertNode :: (Node -> Bool) -> InfoTable -> Node -> Node @@ -23,6 +24,7 @@ convertNode isFoldable tab = rmap go NLet Let {..} | isImmediate tab (_letItem ^. letItemValue) || isVarApp _letBody + || Info.freeVarOccurrences 0 _letBody <= 1 || isFoldable (_letItem ^. letItemValue) -> go (recur . (mkBCRemove (_letItem ^. letItemBinder) val' :)) _letBody where @@ -36,7 +38,13 @@ convertNode isFoldable tab = rmap go in h == mkVar' 0 letFolding' :: (Node -> Bool) -> InfoTable -> InfoTable -letFolding' isFoldable tab = mapAllNodes (convertNode isFoldable tab) tab +letFolding' isFoldable tab = + mapAllNodes + ( removeInfo kFreeVarsInfo + . convertNode isFoldable tab + . computeFreeVarsInfo + ) + tab letFolding :: InfoTable -> InfoTable letFolding = letFolding' (const False) diff --git a/tests/Core/positive/test034.jvc b/tests/Core/positive/test034.jvc index de1a2d82f..9811a8475 100644 --- a/tests/Core/positive/test034.jvc +++ b/tests/Core/positive/test034.jvc @@ -6,7 +6,7 @@ def f := \x \y if x = 0 then 9 else trace 1 >>> (f (x - 1) (y 0)); def h := \x trace 8 >>> trace x >>> x + x; -def const := \x \y x; +def const := \x \y y >>> x; type list { nil : list; @@ -15,7 +15,7 @@ type list { trace (const 0 (trace "!" >>> 1)) >>> trace (const 0 (trace "a" >>> cons 1 (trace "b" >>> trace "c" >>> cons 1 (trace "d" >>> nil)))) >>> -trace ((\x \y \z trace "2" >>> x + y + (trace "3" >>> z)) (trace "1" >>> 1) 2 3) >>> +trace ((\x \y \z x >>> trace "2" >>> x + y + (trace "3" >>> z)) (trace "1" >>> 1) 2 3) >>> trace (f 5 g) >>> trace 7 >>> h (trace 2 >>> 3)