1
1
mirror of https://github.com/anoma/juvix.git synced 2024-11-30 05:42:26 +03:00

Don't fold lets if the let-bound variable occurs under a lambda-abstraction (#3029)

* Closes #3002
This commit is contained in:
Łukasz Czajka 2024-09-13 19:29:39 +02:00 committed by GitHub
parent ef0bc6efb8
commit b609e1f6a5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 113 additions and 60 deletions

View File

@ -177,52 +177,52 @@ isFalseConstr = \case
NCtr Constr {..} | _constrTag == BuiltinTag TagFalse -> True
_ -> False
isDebugOp :: Node -> Bool
isDebugOp = \case
NBlt BuiltinApp {..} ->
case _builtinAppOp of
OpTrace -> True
OpFail -> True
OpSeq -> True
OpAssert -> False
OpAnomaByteArrayFromAnomaContents -> False
OpAnomaByteArrayToAnomaContents -> False
OpAnomaDecode -> False
OpAnomaEncode -> False
OpAnomaGet -> False
OpAnomaSign -> False
OpAnomaSignDetached -> False
OpAnomaVerifyDetached -> False
OpAnomaVerifyWithMessage -> False
OpEc -> False
OpFieldAdd -> False
OpFieldDiv -> False
OpFieldFromInt -> False
OpFieldMul -> False
OpFieldSub -> False
OpPoseidonHash -> False
OpRandomEcPoint -> False
OpStrConcat -> False
OpStrToInt -> False
OpUInt8FromInt -> False
OpUInt8ToInt -> False
OpByteArrayFromListByte -> False
OpByteArrayLength -> False
OpEq -> False
OpIntAdd -> False
OpIntDiv -> False
OpIntLe -> False
OpIntLt -> False
OpIntMod -> False
OpIntMul -> False
OpIntSub -> False
OpFieldToInt -> False
OpShow -> False
_ -> False
-- | Check if the node contains `trace`, `fail` or `seq` (`>->`).
containsDebugOperations :: Node -> Bool
containsDebugOperations = ufold (\x xs -> x || or xs) isDebugOp
where
isDebugOp :: Node -> Bool
isDebugOp = \case
NBlt BuiltinApp {..} ->
case _builtinAppOp of
OpTrace -> True
OpFail -> True
OpSeq -> True
OpAssert -> False
OpAnomaByteArrayFromAnomaContents -> False
OpAnomaByteArrayToAnomaContents -> False
OpAnomaDecode -> False
OpAnomaEncode -> False
OpAnomaGet -> False
OpAnomaSign -> False
OpAnomaSignDetached -> False
OpAnomaVerifyDetached -> False
OpAnomaVerifyWithMessage -> False
OpEc -> False
OpFieldAdd -> False
OpFieldDiv -> False
OpFieldFromInt -> False
OpFieldMul -> False
OpFieldSub -> False
OpPoseidonHash -> False
OpRandomEcPoint -> False
OpStrConcat -> False
OpStrToInt -> False
OpUInt8FromInt -> False
OpUInt8ToInt -> False
OpByteArrayFromListByte -> False
OpByteArrayLength -> False
OpEq -> False
OpIntAdd -> False
OpIntDiv -> False
OpIntLe -> False
OpIntLt -> False
OpIntMod -> False
OpIntMul -> False
OpIntSub -> False
OpFieldToInt -> False
OpShow -> False
_ -> False
containsDebugOps :: Node -> Bool
containsDebugOps = ufold (\x xs -> x || or xs) isDebugOp
freeVarsSortedMany :: [Node] -> Set Var
freeVarsSortedMany n = Set.fromList (n ^.. each . freeVars)

View File

@ -19,7 +19,12 @@ makeLenses ''FreeVarsInfo
-- | Computes free variable info for each subnode. Assumption: no subnode is a
-- closure.
computeFreeVarsInfo :: Node -> Node
computeFreeVarsInfo = umap go
computeFreeVarsInfo = computeFreeVarsInfo' 1
-- | `lambdaMultiplier` specifies how much to multiply the free variable count
-- for variables under lambdas
computeFreeVarsInfo' :: Int -> Node -> Node
computeFreeVarsInfo' lambdaMultiplier = umap go
where
go :: Node -> Node
go node = case node of
@ -27,6 +32,13 @@ computeFreeVarsInfo = umap go
mkVar (Info.insert fvi _varInfo) _varIndex
where
fvi = FreeVarsInfo (Map.singleton _varIndex 1)
NLam Lambda {..} ->
modifyInfo (Info.insert fvi) node
where
fvi =
FreeVarsInfo
. fmap (* lambdaMultiplier)
$ getFreeVars 1 _lambdaBody
_ ->
modifyInfo (Info.insert fvi) node
where
@ -35,14 +47,17 @@ computeFreeVarsInfo = umap go
foldr
( \NodeChild {..} acc ->
Map.unionWith (+) acc $
Map.mapKeysMonotonic (\idx -> idx - _childBindersNum) $
Map.filterWithKey
(\idx _ -> idx >= _childBindersNum)
(getFreeVarsInfo _childNode ^. infoFreeVars)
getFreeVars _childBindersNum _childNode
)
mempty
(children node)
getFreeVars :: Int -> Node -> Map Index Int
getFreeVars bindersNum node =
Map.mapKeysMonotonic (\idx -> idx - bindersNum)
. Map.filterWithKey (\idx _ -> idx >= bindersNum)
$ getFreeVarsInfo node ^. infoFreeVars
getFreeVarsInfo :: Node -> FreeVarsInfo
getFreeVarsInfo = fromJust . Info.lookup kFreeVarsInfo . getInfo

View File

@ -77,4 +77,4 @@ convertNode = dmap go
-- - https://github.com/anoma/juvix/issues/1654
-- - https://github.com/anoma/juvix/pull/1659
moveApps :: Module -> Module
moveApps = mapT (const convertNode)
moveApps = mapAllNodes convertNode

View File

@ -75,6 +75,8 @@ convertNode inlineDepth nonRecSyms md = dmapL go
NIdt Ident {..} -> case pi of
Just InlineCase ->
NCase cs {_caseValue = mkApps def args}
Just InlineNever ->
node
Nothing
| HashSet.member _identSymbol nonRecSyms
&& isConstructorApp def

View File

@ -1,6 +1,7 @@
-- An optimizing transformation that folds lets whose values are immediate,
-- i.e., they don't require evaluation or memory allocation (variables or
-- constants), or when the bound variable occurs at most once in the body.
-- constants), or when the bound variable occurs at most once in the body but
-- not under a lambda-abstraction.
--
-- For example, transforms
-- ```
@ -27,7 +28,7 @@ convertNode isFoldable md = rmapL go
|| Info.freeVarOccurrences 0 _letBody <= 1
|| isFoldable md bl (_letItem ^. letItemValue)
)
&& not (containsDebugOperations _letBody) ->
&& not (containsDebugOps _letBody) ->
go (recur . (mkBCRemove b val' :)) (BL.cons b bl) _letBody
where
val' = go recur bl (_letItem ^. letItemValue)
@ -40,7 +41,11 @@ letFolding' isFoldable tab =
mapAllNodes
( removeInfo kFreeVarsInfo
. convertNode isFoldable tab
. computeFreeVarsInfo
. computeFreeVarsInfo' 2
-- 2 is the lambda multiplier factor which guarantees that every free
-- variable under a lambda is counted at least twice, preventing let
-- folding for let-bound variables (with non-immediate values) that
-- occur under lambdas
)
tab

View File

@ -451,7 +451,7 @@ tests =
$(mkRelDir ".")
$(mkRelFile "test075.juvix")
$(mkRelFile "out/test075.out"),
posTestEval
posTest
"Test076: Builtin Maybe"
$(mkRelDir ".")
$(mkRelFile "test076.juvix")
@ -466,9 +466,19 @@ tests =
$(mkRelDir ".")
$(mkRelFile "test078.juvix")
$(mkRelFile "out/test078.out"),
posTestEval
posTest
"Test079: Let / LetRec type inference (during lambda lifting) in Core"
$(mkRelDir ".")
$(mkRelFile "test079.juvix")
$(mkRelFile "out/test079.out")
$(mkRelFile "out/test079.out"),
posTestEval -- TODO: this test is not compiling
"Test080: Do notation"
$(mkRelDir ".")
$(mkRelFile "test080.juvix")
$(mkRelFile "out/test080.out"),
posTest
"Test081: Non-duplication in let-folding"
$(mkRelDir ".")
$(mkRelFile "test081.juvix")
$(mkRelFile "out/test081.out")
]

View File

@ -0,0 +1,2 @@
nothing
just 1

View File

@ -0,0 +1 @@
0

View File

@ -1,15 +1,15 @@
-- builtin list
module test059;
import Stdlib.Prelude open hiding {head};
import Stdlib.Prelude open;
mylist : List Nat := [1; 2; 3 + 1];
mylist2 : List (List Nat) := [[10]; [2]; 3 + 1 :: nil];
head : {a : Type} -> a -> List a -> a
head' : {a : Type} -> a -> List a -> a
| a [] := a
| a [x; _] := x
| _ (h :: _) := h;
main : Nat := head 50 mylist + head 50 (head [] mylist2);
main : Nat := head' 50 mylist + head' 50 (head' [] mylist2);

View File

@ -0,0 +1,18 @@
-- Non-duplication in let-folding
module test081;
import Stdlib.Prelude open;
{-# inline: false #-}
g (h : Nat -> Nat) : Nat := h 0 * h 0;
terminating
f (n : Nat) : Nat :=
if
| n == 0 := 0
| else :=
let terminating x := f (sub n 1)
in
g \{_ := x};
main : Nat := f 10000;