From 1dfcf1cafc939b745d86e19cea4a982db79d9f6c Mon Sep 17 00:00:00 2001 From: Jaroslav Tulach Date: Wed, 14 Dec 2022 05:00:38 +0100 Subject: [PATCH] AvoidIdInstrumentationTagTest to control which nodes are instrumentable (#3977) Fighting with _too many messages being delivered_ I wrote a test that dumps information about `AvoidIdInstrumentationTag` - every node that has `AvoidIdInstrumentationTag` is excluded from the instrumentation. However, when I look at the output for ``` from Standard.Base import all import Standard.Visualization run n = 0.up_to n . map i-> 1.noise * i ``` I see that `1.noise` didn't have the tag. Now there is [AvoidIdInstrumentationTagTest.java](https://github.com/enso-org/enso/pull/3973/files#diff-32cd9240bda2bfe0e5904695ced008daba86fefb3d137ac401997f4265fa50eb) which can be used to collect all programs where _too many messages is being delivered_. Just add a program, identify _isLambda_ and verify all nodes are properly tagged. --- .../AvoidIdInstrumentationTagTest.java | 119 ++++++++++++++++++ .../FunctionCallInstrumentationNode.java | 5 + .../node/callable/function/StatementNode.java | 5 + .../enso/compiler/codegen/IrToTruffle.scala | 52 +++++--- 4 files changed, 164 insertions(+), 17 deletions(-) create mode 100644 engine/runtime-with-instruments/src/test/java/org/enso/interpreter/test/instrument/AvoidIdInstrumentationTagTest.java diff --git a/engine/runtime-with-instruments/src/test/java/org/enso/interpreter/test/instrument/AvoidIdInstrumentationTagTest.java b/engine/runtime-with-instruments/src/test/java/org/enso/interpreter/test/instrument/AvoidIdInstrumentationTagTest.java new file mode 100644 index 0000000000..8ec1f9eeb9 --- /dev/null +++ b/engine/runtime-with-instruments/src/test/java/org/enso/interpreter/test/instrument/AvoidIdInstrumentationTagTest.java @@ -0,0 +1,119 @@ +package org.enso.interpreter.test.instrument; +import com.oracle.truffle.api.instrumentation.InstrumentableNode; +import com.oracle.truffle.api.instrumentation.StandardTags; +import com.oracle.truffle.api.nodes.Node; +import com.oracle.truffle.api.source.SourceSection; +import java.io.OutputStream; +import java.nio.file.Paths; +import java.util.List; +import java.util.Map; +import java.util.function.Predicate; +import org.enso.interpreter.node.ClosureRootNode; +import org.enso.interpreter.runtime.tag.AvoidIdInstrumentationTag; +import org.enso.interpreter.runtime.tag.IdentifiedTag; +import org.enso.interpreter.test.NodeCountingTestInstrument; +import org.enso.polyglot.RuntimeOptions; +import org.graalvm.polyglot.Context; +import org.graalvm.polyglot.Engine; +import org.graalvm.polyglot.Language; +import org.junit.After; +import org.junit.Assert; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; +import org.junit.Before; +import org.junit.Test; + +public class AvoidIdInstrumentationTagTest { + + private Engine engine; + private Context context; + private NodeCountingTestInstrument nodes; + + @Before + public void initContext() { + engine = Engine.newBuilder() + .allowExperimentalOptions(true) + .option( + RuntimeOptions.LANGUAGE_HOME_OVERRIDE, + Paths.get("../../distribution/component").toFile().getAbsolutePath() + ) + .logHandler(OutputStream.nullOutputStream()) + .build(); + + context = Context.newBuilder() + .engine(engine) + .allowExperimentalOptions(true) + .allowIO(true) + .allowAllAccess(true) + .build(); + + Map langs = engine.getLanguages(); + Assert.assertNotNull("Enso found: " + langs, langs.get("enso")); + + nodes = engine.getInstruments().get(NodeCountingTestInstrument.INSTRUMENT_ID).lookup(NodeCountingTestInstrument.class); + nodes.enable(); + } + + @After + public void disposeContext() { + context.close(); + engine.close(); + } + + @Test + public void avoidIdInstrumentationInLambdaMapFunctionWithNoise() { + var code = """ + from Standard.Base import all + import Standard.Visualization + + run n = 0.up_to n . map i-> 1.noise * i + """; + + var module = context.eval("enso", code); + var run = module.invokeMember("eval_expression", "run"); + var res = run.execute(10000); + assertEquals("Array of the requested size computed", 10000, res.getArraySize()); + + Predicate isLambda = (ss) -> { + var st = ss.getCharacters().toString(); + return st.contains("noise") && !st.contains("map"); + }; + + assertAvoidIdInstrumentationTag(isLambda); + } + + private void assertAvoidIdInstrumentationTag(Predicate isLambda) { + var found = nodes.assertNewNodes("Give me nodes", 0, 10000); + var err = new StringBuilder(); + var missingTagInLambda = false; + for (var nn : found.values()) { + for (var n : nn) { + var ss = n.getSourceSection(); + if (ss == null) { + continue; + } + if (isLambda.test(ss)) { + err.append("\n").append("code: ").append(ss.getCharacters()).append(" for node ").append(n.getClass().getName()); + if (n instanceof InstrumentableNode in) { + final boolean hasAvoidIdInstrumentationTag = in.hasTag(AvoidIdInstrumentationTag.class); + if (!hasAvoidIdInstrumentationTag) { + missingTagInLambda = true; + } + + err.append("\n").append(" AvoidIdInstrumentationTag: ").append(hasAvoidIdInstrumentationTag); + err.append("\n").append(" IdentifiedTag: ").append(in.hasTag(IdentifiedTag.class)); + err.append("\n").append(" ExpressionTag: ").append(in.hasTag(StandardTags.ExpressionTag.class)); + err.append("\n").append(" RootNode: ").append(n.getRootNode()); + if (n.getRootNode() instanceof ClosureRootNode crn) { + err.append("\n").append(" ClosureRootNode.subject to instr: ").append(crn.isSubjectToInstrumentation()); + err.append("\n").append(" ClosureRootNode.used in bindings: ").append(crn.isUsedInBinding()); + } + } + } + } + } + if (missingTagInLambda) { + fail(err.toString()); + } + } +} diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/callable/FunctionCallInstrumentationNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/callable/FunctionCallInstrumentationNode.java index 863bf49377..b7aaf3e5f5 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/callable/FunctionCallInstrumentationNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/callable/FunctionCallInstrumentationNode.java @@ -21,6 +21,8 @@ import org.enso.interpreter.runtime.tag.IdentifiedTag; import java.util.Arrays; import java.util.UUID; +import org.enso.interpreter.node.ClosureRootNode; +import org.enso.interpreter.runtime.tag.AvoidIdInstrumentationTag; /** * A node used for instrumenting function calls. It does nothing useful from the language @@ -151,6 +153,9 @@ public class FunctionCallInstrumentationNode extends Node implements Instrumenta */ @Override public boolean hasTag(Class tag) { + if (AvoidIdInstrumentationTag.class == tag) { + return getRootNode() instanceof ClosureRootNode c && !c.isSubjectToInstrumentation(); + } return tag == StandardTags.CallTag.class || (tag == IdentifiedTag.class && id != null); } diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/callable/function/StatementNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/callable/function/StatementNode.java index 2329abeea0..aab00d4d56 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/callable/function/StatementNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/callable/function/StatementNode.java @@ -4,7 +4,9 @@ import com.oracle.truffle.api.frame.VirtualFrame; import com.oracle.truffle.api.instrumentation.StandardTags; import com.oracle.truffle.api.instrumentation.Tag; import com.oracle.truffle.api.source.SourceSection; +import org.enso.interpreter.node.ClosureRootNode; import org.enso.interpreter.node.ExpressionNode; +import org.enso.interpreter.runtime.tag.AvoidIdInstrumentationTag; /** * Node tagged with {@link StandardTags.StatementTag}. Inserted by {@link BlockNode} into the AST @@ -42,6 +44,9 @@ final class StatementNode extends ExpressionNode { @Override public boolean hasTag(Class tag) { + if (AvoidIdInstrumentationTag.class == tag) { + return getRootNode() instanceof ClosureRootNode c && !c.isSubjectToInstrumentation(); + } return StandardTags.StatementTag.class == tag; } } diff --git a/engine/runtime/src/main/scala/org/enso/compiler/codegen/IrToTruffle.scala b/engine/runtime/src/main/scala/org/enso/compiler/codegen/IrToTruffle.scala index 3765f3ee64..6f824c37d8 100644 --- a/engine/runtime/src/main/scala/org/enso/compiler/codegen/IrToTruffle.scala +++ b/engine/runtime/src/main/scala/org/enso/compiler/codegen/IrToTruffle.scala @@ -391,7 +391,8 @@ class IrToTruffle( new expressionProcessor.BuildFunctionBody( fn.arguments, fn.body, - effectContext + effectContext, + true ) val rootNode = MethodRootNode.build( language, @@ -472,7 +473,8 @@ class IrToTruffle( new expressionProcessor.BuildFunctionBody( fn.arguments, fn.body, - None + None, + true ) val rootNode = MethodRootNode.build( language, @@ -763,13 +765,18 @@ class IrToTruffle( * @param ir the IR to generate code for * @return a truffle expression that represents the same program as `ir` */ - def run(ir: IR.Expression): RuntimeExpression = run(ir, false) + def run(ir: IR.Expression): RuntimeExpression = run(ir, false, true) - private def run(ir: IR.Expression, binding: Boolean): RuntimeExpression = { + private def run( + ir: IR.Expression, + binding: Boolean, + subjectToInstrumentation: Boolean + ): RuntimeExpression = { val runtimeExpression = ir match { - case block: IR.Expression.Block => processBlock(block) - case literal: IR.Literal => processLiteral(literal) - case app: IR.Application => processApplication(app) + case block: IR.Expression.Block => processBlock(block) + case literal: IR.Literal => processLiteral(literal) + case app: IR.Application => + processApplication(app, subjectToInstrumentation) case name: IR.Name => processName(name) case function: IR.Function => processFunction(function, binding) case binding: IR.Expression.Binding => processBinding(binding) @@ -1215,7 +1222,7 @@ class IrToTruffle( val slotIdx = scope.getVarSlotIdx(occInfo.id) setLocation( - AssignmentNode.build(this.run(binding.expression, true), slotIdx), + AssignmentNode.build(this.run(binding.expression, true, true), slotIdx), binding.location ) } @@ -1464,7 +1471,8 @@ class IrToTruffle( class BuildFunctionBody( val arguments: List[IR.DefinitionArgument], val body: IR.Expression, - val effectContext: Option[String] + val effectContext: Option[String], + val subjectToInstrumentation: Boolean ) { private val argFactory = new DefinitionArgumentProcessor(scopeName, scope) private lazy val slots = computeSlots() @@ -1484,7 +1492,8 @@ class IrToTruffle( arguments.map(_.name.name), argSlotIdxs ) - case _ => ExpressionProcessor.this.run(body) + case _ => + ExpressionProcessor.this.run(body, false, subjectToInstrumentation) } val block = BlockNode.build(argExpressions.toArray, bodyExpr) effectContext match { @@ -1570,7 +1579,7 @@ class IrToTruffle( location: Option[IdentifiedLocation], binding: Boolean = false ): CreateFunctionNode = { - val bodyBuilder = new BuildFunctionBody(arguments, body, None) + val bodyBuilder = new BuildFunctionBody(arguments, body, None, binding) val fnRootNode = ClosureRootNode.build( language, scope, @@ -1617,12 +1626,15 @@ class IrToTruffle( * @param application the function application to generate code for * @return the truffle nodes corresponding to `application` */ - def processApplication(application: IR.Application): RuntimeExpression = + def processApplication( + application: IR.Application, + subjectToInstrumentation: Boolean + ): RuntimeExpression = application match { case IR.Application.Prefix(fn, Nil, true, _, _, _) => run(fn) case app: IR.Application.Prefix => - processApplicationWithArgs(app) + processApplicationWithArgs(app, subjectToInstrumentation) case IR.Application.Force(expr, location, _, _) => setLocation(ForceNode.build(this.run(expr)), location) case IR.Application.Literal.Sequence(items, location, _, _) => @@ -1653,7 +1665,8 @@ class IrToTruffle( } private def processApplicationWithArgs( - application: IR.Application.Prefix + application: IR.Application.Prefix, + subjectToInstrumentation: Boolean ): RuntimeExpression = { val IR.Application.Prefix(fn, args, hasDefaultsSuspended, loc, _, _) = application @@ -1663,7 +1676,8 @@ class IrToTruffle( val callArgs = new ArrayBuffer[CallArgument]() for ((unprocessedArg, position) <- arguments.view.zipWithIndex) { - val arg = callArgFactory.run(unprocessedArg, position) + val arg = + callArgFactory.run(unprocessedArg, position, subjectToInstrumentation) callArgs.append(arg) } @@ -1714,7 +1728,11 @@ class IrToTruffle( * @return a truffle construct corresponding to the argument definition * `arg` */ - def run(arg: IR.CallArgument, position: Int): CallArgument = + def run( + arg: IR.CallArgument, + position: Int, + subjectToInstrumentation: Boolean + ): CallArgument = arg match { case IR.CallArgument.Specified( name, @@ -1765,7 +1783,7 @@ class IrToTruffle( argumentExpression, section, displayName, - true, + subjectToInstrumentation, false ) val callTarget = closureRootNode.getCallTarget