From 86c18f37afa0192c8adf863b686a80f329ac12d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Czajka?= <62751+lukaszcz@users.noreply.github.com> Date: Fri, 24 Mar 2023 12:35:47 +0100 Subject: [PATCH] Let folding (#1921) * Closes #1899 --- .../Compiler/Core/Data/TransformationId.hs | 3 +- .../Core/Data/TransformationId/Parser.hs | 4 +++ .../Compiler/Core/Extra/Recursors/RMap.hs | 3 ++ src/Juvix/Compiler/Core/Extra/Utils.hs | 9 ++++++ src/Juvix/Compiler/Core/Transformation.hs | 2 ++ .../Transformation/Optimize/LetFolding.hs | 32 +++++++++++++++++++ 6 files changed, 52 insertions(+), 1 deletion(-) create mode 100644 src/Juvix/Compiler/Core/Transformation/Optimize/LetFolding.hs diff --git a/src/Juvix/Compiler/Core/Data/TransformationId.hs b/src/Juvix/Compiler/Core/Data/TransformationId.hs index d9559d722..cd8ee3248 100644 --- a/src/Juvix/Compiler/Core/Data/TransformationId.hs +++ b/src/Juvix/Compiler/Core/Data/TransformationId.hs @@ -17,6 +17,7 @@ data TransformationId | EtaExpandApps | DisambiguateNames | CheckGeb + | LetFolding deriving stock (Data, Bounded, Enum) data PipelineId @@ -51,7 +52,7 @@ toGebTransformations :: [TransformationId] toGebTransformations = toEvalTransformations ++ [LetRecLifting, CheckGeb, UnrollRecursion, ComputeTypeInfo] toEvalTransformations :: [TransformationId] -toEvalTransformations = [EtaExpandApps, MatchToCase, NatToInt, ConvertBuiltinTypes] +toEvalTransformations = [EtaExpandApps, MatchToCase, NatToInt, ConvertBuiltinTypes, LetFolding] pipeline :: PipelineId -> [TransformationId] pipeline = \case diff --git a/src/Juvix/Compiler/Core/Data/TransformationId/Parser.hs b/src/Juvix/Compiler/Core/Data/TransformationId/Parser.hs index 1979b1623..280025c50 100644 --- a/src/Juvix/Compiler/Core/Data/TransformationId/Parser.hs +++ b/src/Juvix/Compiler/Core/Data/TransformationId/Parser.hs @@ -81,6 +81,7 @@ transformationText = \case UnrollRecursion -> strUnrollRecursion DisambiguateNames -> strDisambiguateNames CheckGeb -> strCheckGeb + LetFolding -> strLetFolding parsePipeline :: MonadParsec e Text m => m PipelineId parsePipeline = choice [symbol (pipelineText t) $> t | t <- allElements] @@ -141,3 +142,6 @@ strDisambiguateNames = "disambiguate-names" strCheckGeb :: Text strCheckGeb = "check-geb" + +strLetFolding :: Text +strLetFolding = "let-folding" diff --git a/src/Juvix/Compiler/Core/Extra/Recursors/RMap.hs b/src/Juvix/Compiler/Core/Extra/Recursors/RMap.hs index 6c223c301..b4951398a 100644 --- a/src/Juvix/Compiler/Core/Extra/Recursors/RMap.hs +++ b/src/Juvix/Compiler/Core/Extra/Recursors/RMap.hs @@ -23,6 +23,9 @@ data BinderChange -- indices of `n` are with respect to the result BCRemove BinderRemove +mkBCRemove :: Binder -> Node -> BinderChange +mkBCRemove b n = BCRemove (BinderRemove b n) + -- | Returns the binders in the original node skipped before a call to `recur`, -- as specified by the BinderChange list. bindersFromBinderChange :: [BinderChange] -> [Binder] diff --git a/src/Juvix/Compiler/Core/Extra/Utils.hs b/src/Juvix/Compiler/Core/Extra/Utils.hs index e4a7aa03a..544bff49e 100644 --- a/src/Juvix/Compiler/Core/Extra/Utils.hs +++ b/src/Juvix/Compiler/Core/Extra/Utils.hs @@ -39,6 +39,15 @@ isTypeConstr tab ty = case typeTarget ty of isTypeConstr tab (fromJust $ HashMap.lookup _identSymbol (tab ^. identContext)) _ -> False +-- True for nodes whose evaluation immediately returns a constant value, i.e., +-- no reduction or memory allocation in the runtime is required. +isImmediate :: Node -> Bool +isImmediate = \case + NVar {} -> True + NIdt {} -> True + NCst {} -> True + _ -> False + freeVarsSorted :: Node -> Set Var freeVarsSorted n = Set.fromList (n ^.. freeVars) diff --git a/src/Juvix/Compiler/Core/Transformation.hs b/src/Juvix/Compiler/Core/Transformation.hs index 462e05476..96013f552 100644 --- a/src/Juvix/Compiler/Core/Transformation.hs +++ b/src/Juvix/Compiler/Core/Transformation.hs @@ -21,6 +21,7 @@ import Juvix.Compiler.Core.Transformation.LambdaLetRecLifting import Juvix.Compiler.Core.Transformation.MatchToCase import Juvix.Compiler.Core.Transformation.MoveApps import Juvix.Compiler.Core.Transformation.NatToInt +import Juvix.Compiler.Core.Transformation.Optimize.LetFolding import Juvix.Compiler.Core.Transformation.RemoveTypeArgs import Juvix.Compiler.Core.Transformation.TopEtaExpand import Juvix.Compiler.Core.Transformation.UnrollRecursion @@ -44,3 +45,4 @@ applyTransformations ts tbl = foldl' (\acc tid -> acc >>= appTrans tid) (return EtaExpandApps -> return . etaExpansionApps DisambiguateNames -> return . disambiguateNames CheckGeb -> mapError (JuvixError @CoreError) . checkGeb + LetFolding -> return . letFolding diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/LetFolding.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/LetFolding.hs new file mode 100644 index 000000000..2e52cc260 --- /dev/null +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/LetFolding.hs @@ -0,0 +1,32 @@ +-- An optimizing transformation that folds lets whose values are immediate, +-- i.e., they don't require evaluation or memory allocation (variables or +-- constants). +-- +-- For example, transforms +-- ``` +-- let x := y in let z := x + x in x + z +-- ``` +-- to +-- ``` +-- let z := y + y in y + z +-- ``` +module Juvix.Compiler.Core.Transformation.Optimize.LetFolding where + +import Juvix.Compiler.Core.Extra +import Juvix.Compiler.Core.Transformation.Base + +convertNode :: Node -> Node +convertNode = rmap go + where + go :: ([BinderChange] -> Node -> Node) -> Node -> Node + go recur = \case + NLet Let {..} + | isImmediate (_letItem ^. letItemValue) -> + go (recur . (mkBCRemove (_letItem ^. letItemBinder) val' :)) _letBody + where + val' = go recur (_letItem ^. letItemValue) + node -> + recur [] node + +letFolding :: InfoTable -> InfoTable +letFolding = mapAllNodes convertNode