diff --git a/src/Juvix/Compiler/Core/Extra/Utils.hs b/src/Juvix/Compiler/Core/Extra/Utils.hs index 33756a7c8..de26939d9 100644 --- a/src/Juvix/Compiler/Core/Extra/Utils.hs +++ b/src/Juvix/Compiler/Core/Extra/Utils.hs @@ -177,52 +177,52 @@ isFalseConstr = \case NCtr Constr {..} | _constrTag == BuiltinTag TagFalse -> True _ -> False +isDebugOp :: Node -> Bool +isDebugOp = \case + NBlt BuiltinApp {..} -> + case _builtinAppOp of + OpTrace -> True + OpFail -> True + OpSeq -> True + OpAssert -> False + OpAnomaByteArrayFromAnomaContents -> False + OpAnomaByteArrayToAnomaContents -> False + OpAnomaDecode -> False + OpAnomaEncode -> False + OpAnomaGet -> False + OpAnomaSign -> False + OpAnomaSignDetached -> False + OpAnomaVerifyDetached -> False + OpAnomaVerifyWithMessage -> False + OpEc -> False + OpFieldAdd -> False + OpFieldDiv -> False + OpFieldFromInt -> False + OpFieldMul -> False + OpFieldSub -> False + OpPoseidonHash -> False + OpRandomEcPoint -> False + OpStrConcat -> False + OpStrToInt -> False + OpUInt8FromInt -> False + OpUInt8ToInt -> False + OpByteArrayFromListByte -> False + OpByteArrayLength -> False + OpEq -> False + OpIntAdd -> False + OpIntDiv -> False + OpIntLe -> False + OpIntLt -> False + OpIntMod -> False + OpIntMul -> False + OpIntSub -> False + OpFieldToInt -> False + OpShow -> False + _ -> False + -- | Check if the node contains `trace`, `fail` or `seq` (`>->`). -containsDebugOperations :: Node -> Bool -containsDebugOperations = ufold (\x xs -> x || or xs) isDebugOp - where - isDebugOp :: Node -> Bool - isDebugOp = \case - NBlt BuiltinApp {..} -> - case _builtinAppOp of - OpTrace -> True - OpFail -> True - OpSeq -> True - OpAssert -> False - OpAnomaByteArrayFromAnomaContents -> False - OpAnomaByteArrayToAnomaContents -> False - OpAnomaDecode -> False - OpAnomaEncode -> False - OpAnomaGet -> False - OpAnomaSign -> False - OpAnomaSignDetached -> False - OpAnomaVerifyDetached -> False - OpAnomaVerifyWithMessage -> False - OpEc -> False - OpFieldAdd -> False - OpFieldDiv -> False - OpFieldFromInt -> False - OpFieldMul -> False - OpFieldSub -> False - OpPoseidonHash -> False - OpRandomEcPoint -> False - OpStrConcat -> False - OpStrToInt -> False - OpUInt8FromInt -> False - OpUInt8ToInt -> False - OpByteArrayFromListByte -> False - OpByteArrayLength -> False - OpEq -> False - OpIntAdd -> False - OpIntDiv -> False - OpIntLe -> False - OpIntLt -> False - OpIntMod -> False - OpIntMul -> False - OpIntSub -> False - OpFieldToInt -> False - OpShow -> False - _ -> False +containsDebugOps :: Node -> Bool +containsDebugOps = ufold (\x xs -> x || or xs) isDebugOp freeVarsSortedMany :: [Node] -> Set Var freeVarsSortedMany n = Set.fromList (n ^.. each . freeVars) diff --git a/src/Juvix/Compiler/Core/Info/FreeVarsInfo.hs b/src/Juvix/Compiler/Core/Info/FreeVarsInfo.hs index 2a665ab4e..312c6cfe6 100644 --- a/src/Juvix/Compiler/Core/Info/FreeVarsInfo.hs +++ b/src/Juvix/Compiler/Core/Info/FreeVarsInfo.hs @@ -19,7 +19,12 @@ makeLenses ''FreeVarsInfo -- | Computes free variable info for each subnode. Assumption: no subnode is a -- closure. computeFreeVarsInfo :: Node -> Node -computeFreeVarsInfo = umap go +computeFreeVarsInfo = computeFreeVarsInfo' 1 + +-- | `lambdaMultiplier` specifies how much to multiply the free variable count +-- for variables under lambdas +computeFreeVarsInfo' :: Int -> Node -> Node +computeFreeVarsInfo' lambdaMultiplier = umap go where go :: Node -> Node go node = case node of @@ -27,6 +32,13 @@ computeFreeVarsInfo = umap go mkVar (Info.insert fvi _varInfo) _varIndex where fvi = FreeVarsInfo (Map.singleton _varIndex 1) + NLam Lambda {..} -> + modifyInfo (Info.insert fvi) node + where + fvi = + FreeVarsInfo + . fmap (* lambdaMultiplier) + $ getFreeVars 1 _lambdaBody _ -> modifyInfo (Info.insert fvi) node where @@ -35,14 +47,17 @@ computeFreeVarsInfo = umap go foldr ( \NodeChild {..} acc -> Map.unionWith (+) acc $ - Map.mapKeysMonotonic (\idx -> idx - _childBindersNum) $ - Map.filterWithKey - (\idx _ -> idx >= _childBindersNum) - (getFreeVarsInfo _childNode ^. infoFreeVars) + getFreeVars _childBindersNum _childNode ) mempty (children node) + getFreeVars :: Int -> Node -> Map Index Int + getFreeVars bindersNum node = + Map.mapKeysMonotonic (\idx -> idx - bindersNum) + . Map.filterWithKey (\idx _ -> idx >= bindersNum) + $ getFreeVarsInfo node ^. infoFreeVars + getFreeVarsInfo :: Node -> FreeVarsInfo getFreeVarsInfo = fromJust . Info.lookup kFreeVarsInfo . getInfo diff --git a/src/Juvix/Compiler/Core/Transformation/MoveApps.hs b/src/Juvix/Compiler/Core/Transformation/MoveApps.hs index 187d46c48..64757d039 100644 --- a/src/Juvix/Compiler/Core/Transformation/MoveApps.hs +++ b/src/Juvix/Compiler/Core/Transformation/MoveApps.hs @@ -77,4 +77,4 @@ convertNode = dmap go -- - https://github.com/anoma/juvix/issues/1654 -- - https://github.com/anoma/juvix/pull/1659 moveApps :: Module -> Module -moveApps = mapT (const convertNode) +moveApps = mapAllNodes convertNode diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs index 1dccadbff..c2cc7c8af 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs @@ -75,6 +75,8 @@ convertNode inlineDepth nonRecSyms md = dmapL go NIdt Ident {..} -> case pi of Just InlineCase -> NCase cs {_caseValue = mkApps def args} + Just InlineNever -> + node Nothing | HashSet.member _identSymbol nonRecSyms && isConstructorApp def diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/LetFolding.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/LetFolding.hs index a84fd6f60..c8752b139 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/LetFolding.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/LetFolding.hs @@ -1,6 +1,7 @@ -- An optimizing transformation that folds lets whose values are immediate, -- i.e., they don't require evaluation or memory allocation (variables or --- constants), or when the bound variable occurs at most once in the body. +-- constants), or when the bound variable occurs at most once in the body but +-- not under a lambda-abstraction. -- -- For example, transforms -- ``` @@ -27,7 +28,7 @@ convertNode isFoldable md = rmapL go || Info.freeVarOccurrences 0 _letBody <= 1 || isFoldable md bl (_letItem ^. letItemValue) ) - && not (containsDebugOperations _letBody) -> + && not (containsDebugOps _letBody) -> go (recur . (mkBCRemove b val' :)) (BL.cons b bl) _letBody where val' = go recur bl (_letItem ^. letItemValue) @@ -40,7 +41,11 @@ letFolding' isFoldable tab = mapAllNodes ( removeInfo kFreeVarsInfo . convertNode isFoldable tab - . computeFreeVarsInfo + . computeFreeVarsInfo' 2 + -- 2 is the lambda multiplier factor which guarantees that every free + -- variable under a lambda is counted at least twice, preventing let + -- folding for let-bound variables (with non-immediate values) that + -- occur under lambdas ) tab diff --git a/test/Compilation/Positive.hs b/test/Compilation/Positive.hs index 51fc1fb28..a7edbb72c 100644 --- a/test/Compilation/Positive.hs +++ b/test/Compilation/Positive.hs @@ -451,7 +451,7 @@ tests = $(mkRelDir ".") $(mkRelFile "test075.juvix") $(mkRelFile "out/test075.out"), - posTestEval + posTest "Test076: Builtin Maybe" $(mkRelDir ".") $(mkRelFile "test076.juvix") @@ -466,9 +466,19 @@ tests = $(mkRelDir ".") $(mkRelFile "test078.juvix") $(mkRelFile "out/test078.out"), - posTestEval + posTest "Test079: Let / LetRec type inference (during lambda lifting) in Core" $(mkRelDir ".") $(mkRelFile "test079.juvix") - $(mkRelFile "out/test079.out") + $(mkRelFile "out/test079.out"), + posTestEval -- TODO: this test is not compiling + "Test080: Do notation" + $(mkRelDir ".") + $(mkRelFile "test080.juvix") + $(mkRelFile "out/test080.out"), + posTest + "Test081: Non-duplication in let-folding" + $(mkRelDir ".") + $(mkRelFile "test081.juvix") + $(mkRelFile "out/test081.out") ] diff --git a/tests/Compilation/positive/out/test080.out b/tests/Compilation/positive/out/test080.out new file mode 100644 index 000000000..66c0e977e --- /dev/null +++ b/tests/Compilation/positive/out/test080.out @@ -0,0 +1,2 @@ +nothing +just 1 diff --git a/tests/Compilation/positive/out/test081.out b/tests/Compilation/positive/out/test081.out new file mode 100644 index 000000000..573541ac9 --- /dev/null +++ b/tests/Compilation/positive/out/test081.out @@ -0,0 +1 @@ +0 diff --git a/tests/Compilation/positive/test059.juvix b/tests/Compilation/positive/test059.juvix index 03a5be832..91ea8dcbc 100644 --- a/tests/Compilation/positive/test059.juvix +++ b/tests/Compilation/positive/test059.juvix @@ -1,15 +1,15 @@ -- builtin list module test059; -import Stdlib.Prelude open hiding {head}; +import Stdlib.Prelude open; mylist : List Nat := [1; 2; 3 + 1]; mylist2 : List (List Nat) := [[10]; [2]; 3 + 1 :: nil]; -head : {a : Type} -> a -> List a -> a +head' : {a : Type} -> a -> List a -> a | a [] := a | a [x; _] := x | _ (h :: _) := h; -main : Nat := head 50 mylist + head 50 (head [] mylist2); +main : Nat := head' 50 mylist + head' 50 (head' [] mylist2); diff --git a/tests/Compilation/positive/test081.juvix b/tests/Compilation/positive/test081.juvix new file mode 100644 index 000000000..73cf7255d --- /dev/null +++ b/tests/Compilation/positive/test081.juvix @@ -0,0 +1,18 @@ +-- Non-duplication in let-folding +module test081; + +import Stdlib.Prelude open; + +{-# inline: false #-} +g (h : Nat -> Nat) : Nat := h 0 * h 0; + +terminating +f (n : Nat) : Nat := + if + | n == 0 := 0 + | else := + let terminating x := f (sub n 1) + in + g \{_ := x}; + +main : Nat := f 10000;