From 6d2f6c673cfba54ecd9f87a492402b103e5b1a4e Mon Sep 17 00:00:00 2001 From: Pavel Marek Date: Thu, 21 Nov 2024 18:36:00 +0100 Subject: [PATCH] PassManager.runInline also chains mini passes. Make the algorithm for running passes generic --- .../enso/compiler/pass/IRProcessingPass.java | 7 +- .../org/enso/compiler/pass/PassManager.scala | 171 ++++++++++-------- 2 files changed, 104 insertions(+), 74 deletions(-) diff --git a/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/IRProcessingPass.java b/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/IRProcessingPass.java index 2d38422d48..8de4d16b8c 100644 --- a/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/IRProcessingPass.java +++ b/engine/runtime-compiler/src/main/java/org/enso/compiler/pass/IRProcessingPass.java @@ -11,6 +11,11 @@ public interface IRProcessingPass extends ProcessingPass { /** The passes that this pass depends _directly_ on to run. */ public Seq precursorPasses(); - /** The passes that are invalidated by running this pass. */ + /** + * The passes that are invalidated by running this pass. + * + *

If {@code P1} invalidates {@code P2}, and {@code P1} is a precursor of {@code P2}, then + * {@code P1} must finish before {@code P2} starts. + */ public Seq invalidatedPasses(); } diff --git a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/PassManager.scala b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/PassManager.scala index cfd2ffde93..168b9c7765 100644 --- a/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/PassManager.scala +++ b/engine/runtime-compiler/src/main/scala/org/enso/compiler/pass/PassManager.scala @@ -4,9 +4,11 @@ import org.enso.common.{Asserts, CompilationStage} import org.slf4j.LoggerFactory import org.enso.compiler.context.{InlineContext, ModuleContext} import org.enso.compiler.core.ir.{Expression, Module} -import org.enso.compiler.core.CompilerError +import org.enso.compiler.core.{CompilerError, IR} import org.enso.compiler.pass.analyse.BindingAnalysis +import scala.collection.mutable.ListBuffer + // TODO [AA] In the future, the pass ordering should be _computed_ from the list // of available passes, rather than just verified. @@ -65,6 +67,7 @@ class PassManager( * @param moduleContext the module context in which the passes are executed * @return the result of executing `passGroup` on `ir` */ + // TODO: Remove this method def runPassesOnModule( ir: Module, moduleContext: ModuleContext @@ -94,77 +97,29 @@ class PassManager( throw new CompilerError("Cannot run an unvalidated pass group.") } - val newContext = - moduleContext.copy(passConfiguration = Some(passConfiguration)) - - val passesWithIndex = passGroup.passes.zipWithIndex - logger.debug( "runPassesOnModule[{}@{}]", moduleContext.getName(), moduleContext.module.getCompilationStage() ) - var pendingMiniPasses: List[MiniPassFactory] = List() - def flushMiniPass(in: Module): Module = { - if (pendingMiniPasses.nonEmpty) { - val miniPasses = pendingMiniPasses.map(factory => - factory.createForModuleCompilation(newContext) - ) - val combinedPass = miniPasses.fold(null)(MiniIRPass.combine) - pendingMiniPasses = List() - if (combinedPass != null) { - logger.trace(" flushing pending mini pass: {}", combinedPass) - MiniIRPass.compile(classOf[Module], in, combinedPass) - } else { - in - } - } else { - in - } - } - val res = passesWithIndex.foldLeft(ir) { - case (intermediateIR, (pass, index)) => { - pass match { - case miniFactory: MiniPassFactory => - logger.trace( - " mini collected: {}", - pass - ) - val combiningPreventedBy = pendingMiniPasses.find { p => - p.invalidatedPasses.contains(miniFactory) - } - val irForRemainingMiniPasses = if (combiningPreventedBy.isDefined) { - logger.trace( - " pass {} forces flush before {}", - combiningPreventedBy.orNull, - miniFactory - ) - flushMiniPass(intermediateIR) - } else { - intermediateIR - } - pendingMiniPasses = pendingMiniPasses.appended(miniFactory) - irForRemainingMiniPasses - case megaPass: IRPass => - // TODO [AA, MK] This is a possible race condition. - passConfiguration - .get(megaPass) - .foreach(c => - c.shouldWriteToContext = isLastRunOf(index, megaPass, passGroup) - ) - val flushedIR = flushMiniPass(intermediateIR) - logger.trace( - " mega running: {}", - megaPass - ) - megaPass.runModule(flushedIR, newContext) - } - } - } - flushMiniPass(res) + + val newContext = + moduleContext.copy(passConfiguration = Some(passConfiguration)) + + runPasses[Module, ModuleContext]( + ir, + newContext, + passGroup, + createMiniPass = + (factory, ctx) => factory.createForModuleCompilation(ctx), + miniPassCompile = (miniPass, ir) => + MiniIRPass.compile[Module](classOf[Module], ir, miniPass), + megaPassCompile = (megaPass, ir, ctx) => megaPass.runModule(ir, ctx) + ) } /** Executes all passes on the [[Expression]]. + * TODO: Remove this method? * * @param ir the expression to execute the compiler passes on * @param inlineContext the inline context in which the passes are executed @@ -198,15 +153,80 @@ class PassManager( val newContext = inlineContext.copy(passConfiguration = Some(passConfiguration)) + runPasses[Expression, InlineContext]( + ir, + newContext, + passGroup, + createMiniPass = + (factory, ctx) => factory.createForInlineCompilation(ctx), + miniPassCompile = (miniPass, ir) => + MiniIRPass.compile[Expression](classOf[Expression], ir, miniPass), + megaPassCompile = (megaPass, ir, ctx) => megaPass.runExpression(ir, ctx) + ) + } + + /** Runs all the passes in the given `passGroup` on `ir` with `context`. + * @param createMiniPass Function that creates a minipass. + * @param miniPassCompile Function that compiles IR with mini pass. + * @param megaPassCompile Function that compiles IR with mega pass. + * @tparam IRType Type of the [[IR]] that is being compiled. + * @tparam ContextType Type of the context for the compilation. + * Either [[ModuleContext]] or [[InlineContext]] + * @return Compiled IR. Might be the same reference as `ir` if no compilation was done. + */ + private def runPasses[IRType <: IR, ContextType]( + ir: IRType, + context: ContextType, + passGroup: PassGroup, + createMiniPass: (MiniPassFactory, ContextType) => MiniIRPass, + miniPassCompile: (MiniIRPass, IRType) => IRType, + megaPassCompile: (IRPass, IRType, ContextType) => IRType + ): IRType = { + val pendingMiniPasses: ListBuffer[MiniPassFactory] = ListBuffer() + + def flushMiniPasses(in: IRType): IRType = { + if (pendingMiniPasses.nonEmpty) { + val miniPasses = + pendingMiniPasses.map(factory => createMiniPass(factory, context)) + val combinedPass = miniPasses.fold(null)(MiniIRPass.combine) + pendingMiniPasses.clear() + if (combinedPass != null) { + logger.trace(" flushing pending mini pass: {}", combinedPass) + miniPassCompile(combinedPass, in) + } else { + in + } + } else { + in + } + } + val passesWithIndex = passGroup.passes.zipWithIndex - - passesWithIndex.foldLeft(ir) { - case (intermediateIR, (pass, index)) => { - + val res = passesWithIndex.foldLeft(ir) { + case (intermediateIR, (pass, index)) => pass match { case miniFactory: MiniPassFactory => - val miniPass = miniFactory.createForInlineCompilation(newContext) - MiniIRPass.compile(classOf[Expression], intermediateIR, miniPass) + logger.trace( + " mini collected: {}", + pass + ) + val combiningPreventedByOpt = pendingMiniPasses.find { p => + p.invalidatedPasses.contains(miniFactory) + } + val irForRemainingMiniPasses = combiningPreventedByOpt match { + case Some(combiningPreventedBy) => + logger.trace( + " pass {} forces flush before (invalidates) {}", + combiningPreventedBy, + miniFactory + ) + flushMiniPasses(intermediateIR) + case None => + intermediateIR + } + pendingMiniPasses.addOne(miniFactory) + irForRemainingMiniPasses + case megaPass: IRPass => // TODO [AA, MK] This is a possible race condition. passConfiguration @@ -214,11 +234,16 @@ class PassManager( .foreach(c => c.shouldWriteToContext = isLastRunOf(index, megaPass, passGroup) ) - megaPass.runExpression(intermediateIR, newContext) + val flushedIR = flushMiniPasses(intermediateIR) + logger.trace( + " mega running: {}", + megaPass + ) + megaPassCompile(megaPass, flushedIR, context) } - - } } + + flushMiniPasses(res) } /** Determines whether the run at index `indexOfPassInGroup` is the last run