From 7167cb319ab7f612a21e3f03b0798293ce0ec79f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Czajka?= <62751+lukaszcz@users.noreply.github.com> Date: Mon, 9 Sep 2024 14:56:36 +0200 Subject: [PATCH] Lift non-immediate expressions out of case values for the Nockma backend (#3010) Implements a transformation `compute-case-anf` which lifts out non-immediate values matched on in case expressions by introducing let-bindings for them. In essence, this is a partial ANF transformation for case expressions only. For example, transforms ``` case f x of { c y := y + x; d y := y } ``` to ``` let z := f x in case z of { c y := y + x; d y := y } ``` This transformation is needed to avoid duplication of values matched on in case-expressions in the Nockma backend. --- app/Commands/Dev/Core/Compile/Base.hs | 2 +- .../Compiler/Core/Data/TransformationId.hs | 2 + .../Core/Data/TransformationId/Strings.hs | 3 + src/Juvix/Compiler/Core/Pipeline.hs | 6 ++ src/Juvix/Compiler/Core/Transformation.hs | 2 + .../Core/Transformation/ComputeCaseANF.hs | 62 +++++++++++++++++++ src/Juvix/Compiler/Pipeline.hs | 26 ++++---- 7 files changed, 91 insertions(+), 12 deletions(-) create mode 100644 src/Juvix/Compiler/Core/Transformation/ComputeCaseANF.hs diff --git a/app/Commands/Dev/Core/Compile/Base.hs b/app/Commands/Dev/Core/Compile/Base.hs index 9ae05a596..588a9dd46 100644 --- a/app/Commands/Dev/Core/Compile/Base.hs +++ b/app/Commands/Dev/Core/Compile/Base.hs @@ -131,7 +131,7 @@ runTreePipeline pa@PipelineArg {..} = do r <- runReader entryPoint . runError @JuvixError - . coreToTree Core.IdentityTrans + . coreToTree Core.IdentityTrans [] $ _pipelineArgModule tab' <- getRight r let code = Tree.ppPrint tab' tab' diff --git a/src/Juvix/Compiler/Core/Data/TransformationId.hs b/src/Juvix/Compiler/Core/Data/TransformationId.hs index edc1a8904..edbb15447 100644 --- a/src/Juvix/Compiler/Core/Data/TransformationId.hs +++ b/src/Juvix/Compiler/Core/Data/TransformationId.hs @@ -16,6 +16,7 @@ data TransformationId | IdentityTrans | UnrollRecursion | ComputeTypeInfo + | ComputeCaseANF | MatchToCase | EtaExpandApps | DisambiguateNames @@ -91,6 +92,7 @@ instance TransformationId' TransformationId where IntToPrimInt -> strIntToPrimInt ConvertBuiltinTypes -> strConvertBuiltinTypes ComputeTypeInfo -> strComputeTypeInfo + ComputeCaseANF -> strComputeCaseANF UnrollRecursion -> strUnrollRecursion DisambiguateNames -> strDisambiguateNames CombineInfoTables -> strCombineInfoTables diff --git a/src/Juvix/Compiler/Core/Data/TransformationId/Strings.hs b/src/Juvix/Compiler/Core/Data/TransformationId/Strings.hs index adaa1d0c7..0ae7b475b 100644 --- a/src/Juvix/Compiler/Core/Data/TransformationId/Strings.hs +++ b/src/Juvix/Compiler/Core/Data/TransformationId/Strings.hs @@ -56,6 +56,9 @@ strConvertBuiltinTypes = "convert-builtin-types" strComputeTypeInfo :: Text strComputeTypeInfo = "compute-type-info" +strComputeCaseANF :: Text +strComputeCaseANF = "compute-case-anf" + strUnrollRecursion :: Text strUnrollRecursion = "unroll-recursion" diff --git a/src/Juvix/Compiler/Core/Pipeline.hs b/src/Juvix/Compiler/Core/Pipeline.hs index a784303e8..4ec9024f0 100644 --- a/src/Juvix/Compiler/Core/Pipeline.hs +++ b/src/Juvix/Compiler/Core/Pipeline.hs @@ -24,3 +24,9 @@ toStripped checkId = mapReader fromEntryPoint . applyTransformations (toStripped -- | Perform transformations on stored Core necessary before the translation to VampIR toVampIR :: (Members '[Error JuvixError, Reader EntryPoint] r) => Module -> Sem r Module toVampIR = mapReader fromEntryPoint . applyTransformations toVampIRTransformations + +extraAnomaTransformations :: [TransformationId] +extraAnomaTransformations = [ComputeCaseANF] + +applyExtraTransformations :: (Members '[Error JuvixError, Reader EntryPoint] r) => [TransformationId] -> Module -> Sem r Module +applyExtraTransformations transforms = mapReader fromEntryPoint . applyTransformations transforms diff --git a/src/Juvix/Compiler/Core/Transformation.hs b/src/Juvix/Compiler/Core/Transformation.hs index da7705c30..03d6d268c 100644 --- a/src/Juvix/Compiler/Core/Transformation.hs +++ b/src/Juvix/Compiler/Core/Transformation.hs @@ -19,6 +19,7 @@ import Juvix.Compiler.Core.Transformation.Check.Exec import Juvix.Compiler.Core.Transformation.Check.Rust import Juvix.Compiler.Core.Transformation.Check.VampIR import Juvix.Compiler.Core.Transformation.CombineInfoTables (combineInfoTables) +import Juvix.Compiler.Core.Transformation.ComputeCaseANF import Juvix.Compiler.Core.Transformation.ComputeTypeInfo import Juvix.Compiler.Core.Transformation.ConvertBuiltinTypes import Juvix.Compiler.Core.Transformation.DisambiguateNames @@ -72,6 +73,7 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts IntToPrimInt -> return . intToPrimInt ConvertBuiltinTypes -> return . convertBuiltinTypes ComputeTypeInfo -> return . computeTypeInfo + ComputeCaseANF -> return . computeCaseANF UnrollRecursion -> unrollRecursion MatchToCase -> mapError (JuvixError @CoreError) . matchToCase EtaExpandApps -> return . etaExpansionApps diff --git a/src/Juvix/Compiler/Core/Transformation/ComputeCaseANF.hs b/src/Juvix/Compiler/Core/Transformation/ComputeCaseANF.hs new file mode 100644 index 000000000..fae9cba46 --- /dev/null +++ b/src/Juvix/Compiler/Core/Transformation/ComputeCaseANF.hs @@ -0,0 +1,62 @@ +module Juvix.Compiler.Core.Transformation.ComputeCaseANF (computeCaseANF) where + +-- A transformation which lifts out non-immediate values matched on in case +-- expressions by introducing let-bindings for them. In essence, this is a +-- partial ANF transformation for case expressions only. +-- +-- For example, transforms +-- ``` +-- case f x of { c y := y + x; d y := y } +-- ``` +-- to +-- ``` +-- let z := f x in case z of { c y := y + x; d y := y } +-- ``` +-- This transformation is needed for the Nockma backend. + +import Juvix.Compiler.Core.Data.BinderList qualified as BL +import Juvix.Compiler.Core.Extra +import Juvix.Compiler.Core.Info.TypeInfo qualified as Info +import Juvix.Compiler.Core.Transformation.Base +import Juvix.Compiler.Core.Transformation.ComputeTypeInfo (computeNodeTypeInfo) + +convertNode :: Module -> Node -> Node +convertNode md = Info.removeTypeInfo . rmapL go . computeNodeTypeInfo md + where + go :: ([BinderChange] -> Node -> Node) -> BinderList Binder -> Node -> Node + go recur bl node = case node of + NCase Case {..} + | not (isImmediate md _caseValue) -> + mkLet _caseInfo b val' $ + NCase + Case + { _caseValue = mkVar' 0, + _caseBranches = map goCaseBranch _caseBranches, + _caseDefault = fmap (go (recur . (BCAdd 1 :)) bl) _caseDefault, + _caseInfo, + _caseInductive + } + where + val' = go recur bl _caseValue + b = Binder "case_value" Nothing ty + ty = Info.getNodeType _caseValue + + goCaseBranch :: CaseBranch -> CaseBranch + goCaseBranch CaseBranch {..} = + CaseBranch + { _caseBranchBody = + go + (recur . ((BCAdd 1 : map BCKeep _caseBranchBinders) ++)) + (BL.prependRev _caseBranchBinders bl) + _caseBranchBody, + _caseBranchTag, + _caseBranchInfo, + _caseBranchBindersNum, + _caseBranchBinders + } + _ -> + recur [] node + +computeCaseANF :: Module -> Module +computeCaseANF md = + mapAllNodes (convertNode md) md diff --git a/src/Juvix/Compiler/Pipeline.hs b/src/Juvix/Compiler/Pipeline.hs index 2f6ac71cf..ad95a0a79 100644 --- a/src/Juvix/Compiler/Pipeline.hs +++ b/src/Juvix/Compiler/Pipeline.hs @@ -165,7 +165,7 @@ upToTree :: (Members '[HighlightBuilder, Reader Parser.ParserResult, Reader EntryPoint, Reader Store.ModuleTable, Files, NameIdGen, Error JuvixError] r) => Sem r Tree.InfoTable upToTree = - upToStoredCore >>= \Core.CoreResult {..} -> storedCoreToTree Core.IdentityTrans _coreResultModule + upToStoredCore >>= \Core.CoreResult {..} -> storedCoreToTree Core.IdentityTrans [] _coreResultModule upToAsm :: (Members '[HighlightBuilder, Reader Parser.ParserResult, Reader EntryPoint, Reader Store.ModuleTable, Files, NameIdGen, Error JuvixError] r) => @@ -226,17 +226,21 @@ upToCoreTypecheck = do storedCoreToTree :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.TransformationId -> + [Core.TransformationId] -> Core.Module -> Sem r Tree.InfoTable -storedCoreToTree checkId md = do +storedCoreToTree checkId extraTransforms md = do fsize <- asks (^. entryPointFieldSize) - Tree.fromCore . Stripped.fromCore fsize . Core.computeCombinedInfoTable <$> Core.toStripped checkId md + Tree.fromCore + . Stripped.fromCore fsize + . Core.computeCombinedInfoTable + <$> (Core.toStripped checkId md >>= Core.applyExtraTransformations extraTransforms) storedCoreToAnoma :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.Module -> Sem r NockmaTree.AnomaResult -storedCoreToAnoma = storedCoreToTree Core.CheckAnoma >=> treeToAnoma +storedCoreToAnoma = storedCoreToTree Core.CheckAnoma Core.extraAnomaTransformations >=> treeToAnoma storedCoreToAsm :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.Module -> Sem r Asm.InfoTable -storedCoreToAsm = storedCoreToTree Core.CheckExec >=> treeToAsm +storedCoreToAsm = storedCoreToTree Core.CheckExec [] >=> treeToAsm storedCoreToReg :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.Module -> Sem r Reg.InfoTable storedCoreToReg = storedCoreToAsm >=> asmToReg @@ -245,13 +249,13 @@ storedCoreToMiniC :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core. storedCoreToMiniC = storedCoreToAsm >=> asmToMiniC storedCoreToRust :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.Module -> Sem r Rust.Result -storedCoreToRust = storedCoreToTree Core.CheckRust >=> treeToReg >=> regToRust +storedCoreToRust = storedCoreToTree Core.CheckRust [] >=> treeToReg >=> regToRust storedCoreToRiscZeroRust :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.Module -> Sem r Rust.Result -storedCoreToRiscZeroRust = storedCoreToTree Core.CheckRust >=> treeToReg >=> regToRiscZeroRust +storedCoreToRiscZeroRust = storedCoreToTree Core.CheckRust [] >=> treeToReg >=> regToRiscZeroRust storedCoreToCasm :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.Module -> Sem r Casm.Result -storedCoreToCasm = local (set entryPointFieldSize cairoFieldSize) . storedCoreToTree Core.CheckCairo >=> treeToCasm +storedCoreToCasm = local (set entryPointFieldSize cairoFieldSize) . storedCoreToTree Core.CheckCairo [] >=> treeToCasm storedCoreToCairo :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.Module -> Sem r Cairo.Result storedCoreToCairo = storedCoreToCasm >=> casmToCairo @@ -263,8 +267,8 @@ storedCoreToVampIR = Core.toVampIR >=> VampIR.fromCore . Core.computeCombinedInf -- Workflows from Core -------------------------------------------------------------------------------- -coreToTree :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.TransformationId -> Core.Module -> Sem r Tree.InfoTable -coreToTree checkId = Core.toStored >=> storedCoreToTree checkId +coreToTree :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.TransformationId -> [Core.TransformationId] -> Core.Module -> Sem r Tree.InfoTable +coreToTree checkId extraTransforms = Core.toStored >=> storedCoreToTree checkId extraTransforms coreToAsm :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.Module -> Sem r Asm.InfoTable coreToAsm = Core.toStored >=> storedCoreToAsm @@ -279,7 +283,7 @@ coreToCairo :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.Module coreToCairo = Core.toStored >=> storedCoreToCairo coreToAnoma :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.Module -> Sem r NockmaTree.AnomaResult -coreToAnoma = coreToTree Core.CheckAnoma >=> treeToAnoma +coreToAnoma = coreToTree Core.CheckAnoma Core.extraAnomaTransformations >=> treeToAnoma coreToRust :: (Members '[Error JuvixError, Reader EntryPoint] r) => Core.Module -> Sem r Rust.Result coreToRust = Core.toStored >=> storedCoreToRust