1
1
mirror of https://github.com/anoma/juvix.git synced 2024-11-30 14:13:27 +03:00

Fold lets when the bound variable occurs at most once (#2231)

For example, convert
```
let x := f a b c in
g x
```
to
```
g (f a b c)
```
This commit is contained in:
Łukasz Czajka 2023-06-29 13:02:10 +02:00 committed by GitHub
parent b4347bdd23
commit 39ef069bfc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 64 additions and 7 deletions

View File

@ -0,0 +1,49 @@
module Juvix.Compiler.Core.Info.FreeVarsInfo where
import Data.Map qualified as Map
import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Info qualified as Info
import Juvix.Compiler.Core.Language
newtype FreeVarsInfo = FreeVarsInfo
{ -- map free variables to the number of their occurrences
_infoFreeVars :: Map Index Int
}
instance IsInfo FreeVarsInfo
kFreeVarsInfo :: Key FreeVarsInfo
kFreeVarsInfo = Proxy
makeLenses ''FreeVarsInfo
computeFreeVarsInfo :: Node -> Node
computeFreeVarsInfo = umap go
where
go :: Node -> Node
go node = case node of
NVar Var {..} ->
mkVar (Info.insert fvi _varInfo) _varIndex
where
fvi = FreeVarsInfo (Map.singleton _varIndex 1)
_ ->
modifyInfo (Info.insert fvi) node
where
fvi =
FreeVarsInfo $
foldr
( \NodeChild {..} acc ->
Map.unionWith (+) acc $
Map.mapKeysMonotonic (\idx -> idx - _childBindersNum) $
Map.filterWithKey
(\idx _ -> idx >= _childBindersNum)
(getFreeVarsInfo _childNode ^. infoFreeVars)
)
mempty
(children node)
getFreeVarsInfo :: Node -> FreeVarsInfo
getFreeVarsInfo = fromJust . Info.lookup kFreeVarsInfo . getInfo
freeVarOccurrences :: Index -> Node -> Int
freeVarOccurrences idx n = fromMaybe 0 (Map.lookup idx (getFreeVarsInfo n ^. infoFreeVars))

View File

@ -1,18 +1,19 @@
-- An optimizing transformation that folds lets whose values are immediate,
-- i.e., they don't require evaluation or memory allocation (variables or
-- constants).
-- constants), or when the bound variable occurs at most once in the body.
--
-- For example, transforms
-- ```
-- let x := y in let z := x + x in x + z
-- let x := y in let z := x + x in let u := z + y in x * x + z * z + u
-- ```
-- to
-- ```
-- let z := y + y in y + z
-- let z := y + y in y * y + z * z + z + y
-- ```
module Juvix.Compiler.Core.Transformation.Optimize.LetFolding where
module Juvix.Compiler.Core.Transformation.Optimize.LetFolding (letFolding, letFolding') where
import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Info.FreeVarsInfo as Info
import Juvix.Compiler.Core.Transformation.Base
convertNode :: (Node -> Bool) -> InfoTable -> Node -> Node
@ -23,6 +24,7 @@ convertNode isFoldable tab = rmap go
NLet Let {..}
| isImmediate tab (_letItem ^. letItemValue)
|| isVarApp _letBody
|| Info.freeVarOccurrences 0 _letBody <= 1
|| isFoldable (_letItem ^. letItemValue) ->
go (recur . (mkBCRemove (_letItem ^. letItemBinder) val' :)) _letBody
where
@ -36,7 +38,13 @@ convertNode isFoldable tab = rmap go
in h == mkVar' 0
letFolding' :: (Node -> Bool) -> InfoTable -> InfoTable
letFolding' isFoldable tab = mapAllNodes (convertNode isFoldable tab) tab
letFolding' isFoldable tab =
mapAllNodes
( removeInfo kFreeVarsInfo
. convertNode isFoldable tab
. computeFreeVarsInfo
)
tab
letFolding :: InfoTable -> InfoTable
letFolding = letFolding' (const False)

View File

@ -6,7 +6,7 @@ def f := \x \y if x = 0 then 9 else trace 1 >>> (f (x - 1) (y 0));
def h := \x trace 8 >>> trace x >>> x + x;
def const := \x \y x;
def const := \x \y y >>> x;
type list {
nil : list;
@ -15,7 +15,7 @@ type list {
trace (const 0 (trace "!" >>> 1)) >>>
trace (const 0 (trace "a" >>> cons 1 (trace "b" >>> trace "c" >>> cons 1 (trace "d" >>> nil)))) >>>
trace ((\x \y \z trace "2" >>> x + y + (trace "3" >>> z)) (trace "1" >>> 1) 2 3) >>>
trace ((\x \y \z x >>> trace "2" >>> x + y + (trace "3" >>> z)) (trace "1" >>> 1) 2 3) >>>
trace (f 5 g) >>>
trace 7 >>>
h (trace 2 >>> 3)