mirror of
https://github.com/anoma/juvix.git
synced 2024-11-30 05:42:26 +03:00
Don't fold lets if the let-bound variable occurs under a lambda-abstraction (#3029)
* Closes #3002
This commit is contained in:
parent
ef0bc6efb8
commit
b609e1f6a5
@ -177,52 +177,52 @@ isFalseConstr = \case
|
||||
NCtr Constr {..} | _constrTag == BuiltinTag TagFalse -> True
|
||||
_ -> False
|
||||
|
||||
isDebugOp :: Node -> Bool
|
||||
isDebugOp = \case
|
||||
NBlt BuiltinApp {..} ->
|
||||
case _builtinAppOp of
|
||||
OpTrace -> True
|
||||
OpFail -> True
|
||||
OpSeq -> True
|
||||
OpAssert -> False
|
||||
OpAnomaByteArrayFromAnomaContents -> False
|
||||
OpAnomaByteArrayToAnomaContents -> False
|
||||
OpAnomaDecode -> False
|
||||
OpAnomaEncode -> False
|
||||
OpAnomaGet -> False
|
||||
OpAnomaSign -> False
|
||||
OpAnomaSignDetached -> False
|
||||
OpAnomaVerifyDetached -> False
|
||||
OpAnomaVerifyWithMessage -> False
|
||||
OpEc -> False
|
||||
OpFieldAdd -> False
|
||||
OpFieldDiv -> False
|
||||
OpFieldFromInt -> False
|
||||
OpFieldMul -> False
|
||||
OpFieldSub -> False
|
||||
OpPoseidonHash -> False
|
||||
OpRandomEcPoint -> False
|
||||
OpStrConcat -> False
|
||||
OpStrToInt -> False
|
||||
OpUInt8FromInt -> False
|
||||
OpUInt8ToInt -> False
|
||||
OpByteArrayFromListByte -> False
|
||||
OpByteArrayLength -> False
|
||||
OpEq -> False
|
||||
OpIntAdd -> False
|
||||
OpIntDiv -> False
|
||||
OpIntLe -> False
|
||||
OpIntLt -> False
|
||||
OpIntMod -> False
|
||||
OpIntMul -> False
|
||||
OpIntSub -> False
|
||||
OpFieldToInt -> False
|
||||
OpShow -> False
|
||||
_ -> False
|
||||
|
||||
-- | Check if the node contains `trace`, `fail` or `seq` (`>->`).
|
||||
containsDebugOperations :: Node -> Bool
|
||||
containsDebugOperations = ufold (\x xs -> x || or xs) isDebugOp
|
||||
where
|
||||
isDebugOp :: Node -> Bool
|
||||
isDebugOp = \case
|
||||
NBlt BuiltinApp {..} ->
|
||||
case _builtinAppOp of
|
||||
OpTrace -> True
|
||||
OpFail -> True
|
||||
OpSeq -> True
|
||||
OpAssert -> False
|
||||
OpAnomaByteArrayFromAnomaContents -> False
|
||||
OpAnomaByteArrayToAnomaContents -> False
|
||||
OpAnomaDecode -> False
|
||||
OpAnomaEncode -> False
|
||||
OpAnomaGet -> False
|
||||
OpAnomaSign -> False
|
||||
OpAnomaSignDetached -> False
|
||||
OpAnomaVerifyDetached -> False
|
||||
OpAnomaVerifyWithMessage -> False
|
||||
OpEc -> False
|
||||
OpFieldAdd -> False
|
||||
OpFieldDiv -> False
|
||||
OpFieldFromInt -> False
|
||||
OpFieldMul -> False
|
||||
OpFieldSub -> False
|
||||
OpPoseidonHash -> False
|
||||
OpRandomEcPoint -> False
|
||||
OpStrConcat -> False
|
||||
OpStrToInt -> False
|
||||
OpUInt8FromInt -> False
|
||||
OpUInt8ToInt -> False
|
||||
OpByteArrayFromListByte -> False
|
||||
OpByteArrayLength -> False
|
||||
OpEq -> False
|
||||
OpIntAdd -> False
|
||||
OpIntDiv -> False
|
||||
OpIntLe -> False
|
||||
OpIntLt -> False
|
||||
OpIntMod -> False
|
||||
OpIntMul -> False
|
||||
OpIntSub -> False
|
||||
OpFieldToInt -> False
|
||||
OpShow -> False
|
||||
_ -> False
|
||||
containsDebugOps :: Node -> Bool
|
||||
containsDebugOps = ufold (\x xs -> x || or xs) isDebugOp
|
||||
|
||||
freeVarsSortedMany :: [Node] -> Set Var
|
||||
freeVarsSortedMany n = Set.fromList (n ^.. each . freeVars)
|
||||
|
@ -19,7 +19,12 @@ makeLenses ''FreeVarsInfo
|
||||
-- | Computes free variable info for each subnode. Assumption: no subnode is a
|
||||
-- closure.
|
||||
computeFreeVarsInfo :: Node -> Node
|
||||
computeFreeVarsInfo = umap go
|
||||
computeFreeVarsInfo = computeFreeVarsInfo' 1
|
||||
|
||||
-- | `lambdaMultiplier` specifies how much to multiply the free variable count
|
||||
-- for variables under lambdas
|
||||
computeFreeVarsInfo' :: Int -> Node -> Node
|
||||
computeFreeVarsInfo' lambdaMultiplier = umap go
|
||||
where
|
||||
go :: Node -> Node
|
||||
go node = case node of
|
||||
@ -27,6 +32,13 @@ computeFreeVarsInfo = umap go
|
||||
mkVar (Info.insert fvi _varInfo) _varIndex
|
||||
where
|
||||
fvi = FreeVarsInfo (Map.singleton _varIndex 1)
|
||||
NLam Lambda {..} ->
|
||||
modifyInfo (Info.insert fvi) node
|
||||
where
|
||||
fvi =
|
||||
FreeVarsInfo
|
||||
. fmap (* lambdaMultiplier)
|
||||
$ getFreeVars 1 _lambdaBody
|
||||
_ ->
|
||||
modifyInfo (Info.insert fvi) node
|
||||
where
|
||||
@ -35,14 +47,17 @@ computeFreeVarsInfo = umap go
|
||||
foldr
|
||||
( \NodeChild {..} acc ->
|
||||
Map.unionWith (+) acc $
|
||||
Map.mapKeysMonotonic (\idx -> idx - _childBindersNum) $
|
||||
Map.filterWithKey
|
||||
(\idx _ -> idx >= _childBindersNum)
|
||||
(getFreeVarsInfo _childNode ^. infoFreeVars)
|
||||
getFreeVars _childBindersNum _childNode
|
||||
)
|
||||
mempty
|
||||
(children node)
|
||||
|
||||
getFreeVars :: Int -> Node -> Map Index Int
|
||||
getFreeVars bindersNum node =
|
||||
Map.mapKeysMonotonic (\idx -> idx - bindersNum)
|
||||
. Map.filterWithKey (\idx _ -> idx >= bindersNum)
|
||||
$ getFreeVarsInfo node ^. infoFreeVars
|
||||
|
||||
getFreeVarsInfo :: Node -> FreeVarsInfo
|
||||
getFreeVarsInfo = fromJust . Info.lookup kFreeVarsInfo . getInfo
|
||||
|
||||
|
@ -77,4 +77,4 @@ convertNode = dmap go
|
||||
-- - https://github.com/anoma/juvix/issues/1654
|
||||
-- - https://github.com/anoma/juvix/pull/1659
|
||||
moveApps :: Module -> Module
|
||||
moveApps = mapT (const convertNode)
|
||||
moveApps = mapAllNodes convertNode
|
||||
|
@ -75,6 +75,8 @@ convertNode inlineDepth nonRecSyms md = dmapL go
|
||||
NIdt Ident {..} -> case pi of
|
||||
Just InlineCase ->
|
||||
NCase cs {_caseValue = mkApps def args}
|
||||
Just InlineNever ->
|
||||
node
|
||||
Nothing
|
||||
| HashSet.member _identSymbol nonRecSyms
|
||||
&& isConstructorApp def
|
||||
|
@ -1,6 +1,7 @@
|
||||
-- An optimizing transformation that folds lets whose values are immediate,
|
||||
-- i.e., they don't require evaluation or memory allocation (variables or
|
||||
-- constants), or when the bound variable occurs at most once in the body.
|
||||
-- constants), or when the bound variable occurs at most once in the body but
|
||||
-- not under a lambda-abstraction.
|
||||
--
|
||||
-- For example, transforms
|
||||
-- ```
|
||||
@ -27,7 +28,7 @@ convertNode isFoldable md = rmapL go
|
||||
|| Info.freeVarOccurrences 0 _letBody <= 1
|
||||
|| isFoldable md bl (_letItem ^. letItemValue)
|
||||
)
|
||||
&& not (containsDebugOperations _letBody) ->
|
||||
&& not (containsDebugOps _letBody) ->
|
||||
go (recur . (mkBCRemove b val' :)) (BL.cons b bl) _letBody
|
||||
where
|
||||
val' = go recur bl (_letItem ^. letItemValue)
|
||||
@ -40,7 +41,11 @@ letFolding' isFoldable tab =
|
||||
mapAllNodes
|
||||
( removeInfo kFreeVarsInfo
|
||||
. convertNode isFoldable tab
|
||||
. computeFreeVarsInfo
|
||||
. computeFreeVarsInfo' 2
|
||||
-- 2 is the lambda multiplier factor which guarantees that every free
|
||||
-- variable under a lambda is counted at least twice, preventing let
|
||||
-- folding for let-bound variables (with non-immediate values) that
|
||||
-- occur under lambdas
|
||||
)
|
||||
tab
|
||||
|
||||
|
@ -451,7 +451,7 @@ tests =
|
||||
$(mkRelDir ".")
|
||||
$(mkRelFile "test075.juvix")
|
||||
$(mkRelFile "out/test075.out"),
|
||||
posTestEval
|
||||
posTest
|
||||
"Test076: Builtin Maybe"
|
||||
$(mkRelDir ".")
|
||||
$(mkRelFile "test076.juvix")
|
||||
@ -466,9 +466,19 @@ tests =
|
||||
$(mkRelDir ".")
|
||||
$(mkRelFile "test078.juvix")
|
||||
$(mkRelFile "out/test078.out"),
|
||||
posTestEval
|
||||
posTest
|
||||
"Test079: Let / LetRec type inference (during lambda lifting) in Core"
|
||||
$(mkRelDir ".")
|
||||
$(mkRelFile "test079.juvix")
|
||||
$(mkRelFile "out/test079.out")
|
||||
$(mkRelFile "out/test079.out"),
|
||||
posTestEval -- TODO: this test is not compiling
|
||||
"Test080: Do notation"
|
||||
$(mkRelDir ".")
|
||||
$(mkRelFile "test080.juvix")
|
||||
$(mkRelFile "out/test080.out"),
|
||||
posTest
|
||||
"Test081: Non-duplication in let-folding"
|
||||
$(mkRelDir ".")
|
||||
$(mkRelFile "test081.juvix")
|
||||
$(mkRelFile "out/test081.out")
|
||||
]
|
||||
|
2
tests/Compilation/positive/out/test080.out
Normal file
2
tests/Compilation/positive/out/test080.out
Normal file
@ -0,0 +1,2 @@
|
||||
nothing
|
||||
just 1
|
1
tests/Compilation/positive/out/test081.out
Normal file
1
tests/Compilation/positive/out/test081.out
Normal file
@ -0,0 +1 @@
|
||||
0
|
@ -1,15 +1,15 @@
|
||||
-- builtin list
|
||||
module test059;
|
||||
|
||||
import Stdlib.Prelude open hiding {head};
|
||||
import Stdlib.Prelude open;
|
||||
|
||||
mylist : List Nat := [1; 2; 3 + 1];
|
||||
|
||||
mylist2 : List (List Nat) := [[10]; [2]; 3 + 1 :: nil];
|
||||
|
||||
head : {a : Type} -> a -> List a -> a
|
||||
head' : {a : Type} -> a -> List a -> a
|
||||
| a [] := a
|
||||
| a [x; _] := x
|
||||
| _ (h :: _) := h;
|
||||
|
||||
main : Nat := head 50 mylist + head 50 (head [] mylist2);
|
||||
main : Nat := head' 50 mylist + head' 50 (head' [] mylist2);
|
||||
|
18
tests/Compilation/positive/test081.juvix
Normal file
18
tests/Compilation/positive/test081.juvix
Normal file
@ -0,0 +1,18 @@
|
||||
-- Non-duplication in let-folding
|
||||
module test081;
|
||||
|
||||
import Stdlib.Prelude open;
|
||||
|
||||
{-# inline: false #-}
|
||||
g (h : Nat -> Nat) : Nat := h 0 * h 0;
|
||||
|
||||
terminating
|
||||
f (n : Nat) : Nat :=
|
||||
if
|
||||
| n == 0 := 0
|
||||
| else :=
|
||||
let terminating x := f (sub n 1)
|
||||
in
|
||||
g \{_ := x};
|
||||
|
||||
main : Nat := f 10000;
|
Loading…
Reference in New Issue
Block a user