PassManager.runInline also chains mini passes.

Make the algorithm for running passes generic
This commit is contained in:
Pavel Marek 2024-11-21 18:36:00 +01:00
parent 0a4ac1f922
commit 6d2f6c673c
2 changed files with 104 additions and 74 deletions

View File

@ -11,6 +11,11 @@ public interface IRProcessingPass extends ProcessingPass {
/** The passes that this pass depends _directly_ on to run. */ /** The passes that this pass depends _directly_ on to run. */
public Seq<? extends IRProcessingPass> precursorPasses(); public Seq<? extends IRProcessingPass> precursorPasses();
/** The passes that are invalidated by running this pass. */ /**
* The passes that are invalidated by running this pass.
*
* <p>If {@code P1} invalidates {@code P2}, and {@code P1} is a precursor of {@code P2}, then
* {@code P1} must finish before {@code P2} starts.
*/
public Seq<? extends IRProcessingPass> invalidatedPasses(); public Seq<? extends IRProcessingPass> invalidatedPasses();
} }

View File

@ -4,9 +4,11 @@ import org.enso.common.{Asserts, CompilationStage}
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import org.enso.compiler.context.{InlineContext, ModuleContext} import org.enso.compiler.context.{InlineContext, ModuleContext}
import org.enso.compiler.core.ir.{Expression, Module} import org.enso.compiler.core.ir.{Expression, Module}
import org.enso.compiler.core.CompilerError import org.enso.compiler.core.{CompilerError, IR}
import org.enso.compiler.pass.analyse.BindingAnalysis import org.enso.compiler.pass.analyse.BindingAnalysis
import scala.collection.mutable.ListBuffer
// TODO [AA] In the future, the pass ordering should be _computed_ from the list // TODO [AA] In the future, the pass ordering should be _computed_ from the list
// of available passes, rather than just verified. // of available passes, rather than just verified.
@ -65,6 +67,7 @@ class PassManager(
* @param moduleContext the module context in which the passes are executed * @param moduleContext the module context in which the passes are executed
* @return the result of executing `passGroup` on `ir` * @return the result of executing `passGroup` on `ir`
*/ */
// TODO: Remove this method
def runPassesOnModule( def runPassesOnModule(
ir: Module, ir: Module,
moduleContext: ModuleContext moduleContext: ModuleContext
@ -94,77 +97,29 @@ class PassManager(
throw new CompilerError("Cannot run an unvalidated pass group.") throw new CompilerError("Cannot run an unvalidated pass group.")
} }
val newContext =
moduleContext.copy(passConfiguration = Some(passConfiguration))
val passesWithIndex = passGroup.passes.zipWithIndex
logger.debug( logger.debug(
"runPassesOnModule[{}@{}]", "runPassesOnModule[{}@{}]",
moduleContext.getName(), moduleContext.getName(),
moduleContext.module.getCompilationStage() moduleContext.module.getCompilationStage()
) )
var pendingMiniPasses: List[MiniPassFactory] = List()
def flushMiniPass(in: Module): Module = { val newContext =
if (pendingMiniPasses.nonEmpty) { moduleContext.copy(passConfiguration = Some(passConfiguration))
val miniPasses = pendingMiniPasses.map(factory =>
factory.createForModuleCompilation(newContext) runPasses[Module, ModuleContext](
ir,
newContext,
passGroup,
createMiniPass =
(factory, ctx) => factory.createForModuleCompilation(ctx),
miniPassCompile = (miniPass, ir) =>
MiniIRPass.compile[Module](classOf[Module], ir, miniPass),
megaPassCompile = (megaPass, ir, ctx) => megaPass.runModule(ir, ctx)
) )
val combinedPass = miniPasses.fold(null)(MiniIRPass.combine)
pendingMiniPasses = List()
if (combinedPass != null) {
logger.trace(" flushing pending mini pass: {}", combinedPass)
MiniIRPass.compile(classOf[Module], in, combinedPass)
} else {
in
}
} else {
in
}
}
val res = passesWithIndex.foldLeft(ir) {
case (intermediateIR, (pass, index)) => {
pass match {
case miniFactory: MiniPassFactory =>
logger.trace(
" mini collected: {}",
pass
)
val combiningPreventedBy = pendingMiniPasses.find { p =>
p.invalidatedPasses.contains(miniFactory)
}
val irForRemainingMiniPasses = if (combiningPreventedBy.isDefined) {
logger.trace(
" pass {} forces flush before {}",
combiningPreventedBy.orNull,
miniFactory
)
flushMiniPass(intermediateIR)
} else {
intermediateIR
}
pendingMiniPasses = pendingMiniPasses.appended(miniFactory)
irForRemainingMiniPasses
case megaPass: IRPass =>
// TODO [AA, MK] This is a possible race condition.
passConfiguration
.get(megaPass)
.foreach(c =>
c.shouldWriteToContext = isLastRunOf(index, megaPass, passGroup)
)
val flushedIR = flushMiniPass(intermediateIR)
logger.trace(
" mega running: {}",
megaPass
)
megaPass.runModule(flushedIR, newContext)
}
}
}
flushMiniPass(res)
} }
/** Executes all passes on the [[Expression]]. /** Executes all passes on the [[Expression]].
* TODO: Remove this method?
* *
* @param ir the expression to execute the compiler passes on * @param ir the expression to execute the compiler passes on
* @param inlineContext the inline context in which the passes are executed * @param inlineContext the inline context in which the passes are executed
@ -198,15 +153,80 @@ class PassManager(
val newContext = val newContext =
inlineContext.copy(passConfiguration = Some(passConfiguration)) inlineContext.copy(passConfiguration = Some(passConfiguration))
runPasses[Expression, InlineContext](
ir,
newContext,
passGroup,
createMiniPass =
(factory, ctx) => factory.createForInlineCompilation(ctx),
miniPassCompile = (miniPass, ir) =>
MiniIRPass.compile[Expression](classOf[Expression], ir, miniPass),
megaPassCompile = (megaPass, ir, ctx) => megaPass.runExpression(ir, ctx)
)
}
/** Runs all the passes in the given `passGroup` on `ir` with `context`.
* @param createMiniPass Function that creates a minipass.
* @param miniPassCompile Function that compiles IR with mini pass.
* @param megaPassCompile Function that compiles IR with mega pass.
* @tparam IRType Type of the [[IR]] that is being compiled.
* @tparam ContextType Type of the context for the compilation.
* Either [[ModuleContext]] or [[InlineContext]]
* @return Compiled IR. Might be the same reference as `ir` if no compilation was done.
*/
private def runPasses[IRType <: IR, ContextType](
ir: IRType,
context: ContextType,
passGroup: PassGroup,
createMiniPass: (MiniPassFactory, ContextType) => MiniIRPass,
miniPassCompile: (MiniIRPass, IRType) => IRType,
megaPassCompile: (IRPass, IRType, ContextType) => IRType
): IRType = {
val pendingMiniPasses: ListBuffer[MiniPassFactory] = ListBuffer()
def flushMiniPasses(in: IRType): IRType = {
if (pendingMiniPasses.nonEmpty) {
val miniPasses =
pendingMiniPasses.map(factory => createMiniPass(factory, context))
val combinedPass = miniPasses.fold(null)(MiniIRPass.combine)
pendingMiniPasses.clear()
if (combinedPass != null) {
logger.trace(" flushing pending mini pass: {}", combinedPass)
miniPassCompile(combinedPass, in)
} else {
in
}
} else {
in
}
}
val passesWithIndex = passGroup.passes.zipWithIndex val passesWithIndex = passGroup.passes.zipWithIndex
val res = passesWithIndex.foldLeft(ir) {
passesWithIndex.foldLeft(ir) { case (intermediateIR, (pass, index)) =>
case (intermediateIR, (pass, index)) => {
pass match { pass match {
case miniFactory: MiniPassFactory => case miniFactory: MiniPassFactory =>
val miniPass = miniFactory.createForInlineCompilation(newContext) logger.trace(
MiniIRPass.compile(classOf[Expression], intermediateIR, miniPass) " mini collected: {}",
pass
)
val combiningPreventedByOpt = pendingMiniPasses.find { p =>
p.invalidatedPasses.contains(miniFactory)
}
val irForRemainingMiniPasses = combiningPreventedByOpt match {
case Some(combiningPreventedBy) =>
logger.trace(
" pass {} forces flush before (invalidates) {}",
combiningPreventedBy,
miniFactory
)
flushMiniPasses(intermediateIR)
case None =>
intermediateIR
}
pendingMiniPasses.addOne(miniFactory)
irForRemainingMiniPasses
case megaPass: IRPass => case megaPass: IRPass =>
// TODO [AA, MK] This is a possible race condition. // TODO [AA, MK] This is a possible race condition.
passConfiguration passConfiguration
@ -214,11 +234,16 @@ class PassManager(
.foreach(c => .foreach(c =>
c.shouldWriteToContext = isLastRunOf(index, megaPass, passGroup) c.shouldWriteToContext = isLastRunOf(index, megaPass, passGroup)
) )
megaPass.runExpression(intermediateIR, newContext) val flushedIR = flushMiniPasses(intermediateIR)
logger.trace(
" mega running: {}",
megaPass
)
megaPassCompile(megaPass, flushedIR, context)
}
} }
} flushMiniPasses(res)
}
} }
/** Determines whether the run at index `indexOfPassInGroup` is the last run /** Determines whether the run at index `indexOfPassInGroup` is the last run