AvoidIdInstrumentationTagTest to control which nodes are instrumentable (#3977)

Fighting with _too many messages being delivered_ I wrote a test that dumps information about `AvoidIdInstrumentationTag` - every node that has `AvoidIdInstrumentationTag` is excluded from the instrumentation. However, when I look at the output for
```
from Standard.Base import all
import Standard.Visualization

run n = 0.up_to n . map i-> 1.noise * i
```
I see that `1.noise` didn't have the tag. Now there is [AvoidIdInstrumentationTagTest.java](https://github.com/enso-org/enso/pull/3973/files#diff-32cd9240bda2bfe0e5904695ced008daba86fefb3d137ac401997f4265fa50eb) which can be used to collect all programs where _too many messages is being delivered_. Just add a program, identify _isLambda_ and verify all nodes are properly tagged.
This commit is contained in:
Jaroslav Tulach 2022-12-14 05:00:38 +01:00 committed by GitHub
parent 965d1ff28b
commit 1dfcf1cafc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 164 additions and 17 deletions

View File

@ -0,0 +1,119 @@
package org.enso.interpreter.test.instrument;
import com.oracle.truffle.api.instrumentation.InstrumentableNode;
import com.oracle.truffle.api.instrumentation.StandardTags;
import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.api.source.SourceSection;
import java.io.OutputStream;
import java.nio.file.Paths;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import org.enso.interpreter.node.ClosureRootNode;
import org.enso.interpreter.runtime.tag.AvoidIdInstrumentationTag;
import org.enso.interpreter.runtime.tag.IdentifiedTag;
import org.enso.interpreter.test.NodeCountingTestInstrument;
import org.enso.polyglot.RuntimeOptions;
import org.graalvm.polyglot.Context;
import org.graalvm.polyglot.Engine;
import org.graalvm.polyglot.Language;
import org.junit.After;
import org.junit.Assert;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
import org.junit.Before;
import org.junit.Test;
public class AvoidIdInstrumentationTagTest {
private Engine engine;
private Context context;
private NodeCountingTestInstrument nodes;
@Before
public void initContext() {
engine = Engine.newBuilder()
.allowExperimentalOptions(true)
.option(
RuntimeOptions.LANGUAGE_HOME_OVERRIDE,
Paths.get("../../distribution/component").toFile().getAbsolutePath()
)
.logHandler(OutputStream.nullOutputStream())
.build();
context = Context.newBuilder()
.engine(engine)
.allowExperimentalOptions(true)
.allowIO(true)
.allowAllAccess(true)
.build();
Map<String, Language> langs = engine.getLanguages();
Assert.assertNotNull("Enso found: " + langs, langs.get("enso"));
nodes = engine.getInstruments().get(NodeCountingTestInstrument.INSTRUMENT_ID).lookup(NodeCountingTestInstrument.class);
nodes.enable();
}
@After
public void disposeContext() {
context.close();
engine.close();
}
@Test
public void avoidIdInstrumentationInLambdaMapFunctionWithNoise() {
var code = """
from Standard.Base import all
import Standard.Visualization
run n = 0.up_to n . map i-> 1.noise * i
""";
var module = context.eval("enso", code);
var run = module.invokeMember("eval_expression", "run");
var res = run.execute(10000);
assertEquals("Array of the requested size computed", 10000, res.getArraySize());
Predicate<SourceSection> isLambda = (ss) -> {
var st = ss.getCharacters().toString();
return st.contains("noise") && !st.contains("map");
};
assertAvoidIdInstrumentationTag(isLambda);
}
private void assertAvoidIdInstrumentationTag(Predicate<SourceSection> isLambda) {
var found = nodes.assertNewNodes("Give me nodes", 0, 10000);
var err = new StringBuilder();
var missingTagInLambda = false;
for (var nn : found.values()) {
for (var n : nn) {
var ss = n.getSourceSection();
if (ss == null) {
continue;
}
if (isLambda.test(ss)) {
err.append("\n").append("code: ").append(ss.getCharacters()).append(" for node ").append(n.getClass().getName());
if (n instanceof InstrumentableNode in) {
final boolean hasAvoidIdInstrumentationTag = in.hasTag(AvoidIdInstrumentationTag.class);
if (!hasAvoidIdInstrumentationTag) {
missingTagInLambda = true;
}
err.append("\n").append(" AvoidIdInstrumentationTag: ").append(hasAvoidIdInstrumentationTag);
err.append("\n").append(" IdentifiedTag: ").append(in.hasTag(IdentifiedTag.class));
err.append("\n").append(" ExpressionTag: ").append(in.hasTag(StandardTags.ExpressionTag.class));
err.append("\n").append(" RootNode: ").append(n.getRootNode());
if (n.getRootNode() instanceof ClosureRootNode crn) {
err.append("\n").append(" ClosureRootNode.subject to instr: ").append(crn.isSubjectToInstrumentation());
err.append("\n").append(" ClosureRootNode.used in bindings: ").append(crn.isUsedInBinding());
}
}
}
}
}
if (missingTagInLambda) {
fail(err.toString());
}
}
}

View File

@ -21,6 +21,8 @@ import org.enso.interpreter.runtime.tag.IdentifiedTag;
import java.util.Arrays;
import java.util.UUID;
import org.enso.interpreter.node.ClosureRootNode;
import org.enso.interpreter.runtime.tag.AvoidIdInstrumentationTag;
/**
* A node used for instrumenting function calls. It does nothing useful from the language
@ -151,6 +153,9 @@ public class FunctionCallInstrumentationNode extends Node implements Instrumenta
*/
@Override
public boolean hasTag(Class<? extends Tag> tag) {
if (AvoidIdInstrumentationTag.class == tag) {
return getRootNode() instanceof ClosureRootNode c && !c.isSubjectToInstrumentation();
}
return tag == StandardTags.CallTag.class || (tag == IdentifiedTag.class && id != null);
}

View File

@ -4,7 +4,9 @@ import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.instrumentation.StandardTags;
import com.oracle.truffle.api.instrumentation.Tag;
import com.oracle.truffle.api.source.SourceSection;
import org.enso.interpreter.node.ClosureRootNode;
import org.enso.interpreter.node.ExpressionNode;
import org.enso.interpreter.runtime.tag.AvoidIdInstrumentationTag;
/**
* Node tagged with {@link StandardTags.StatementTag}. Inserted by {@link BlockNode} into the AST
@ -42,6 +44,9 @@ final class StatementNode extends ExpressionNode {
@Override
public boolean hasTag(Class<? extends Tag> tag) {
if (AvoidIdInstrumentationTag.class == tag) {
return getRootNode() instanceof ClosureRootNode c && !c.isSubjectToInstrumentation();
}
return StandardTags.StatementTag.class == tag;
}
}

View File

@ -391,7 +391,8 @@ class IrToTruffle(
new expressionProcessor.BuildFunctionBody(
fn.arguments,
fn.body,
effectContext
effectContext,
true
)
val rootNode = MethodRootNode.build(
language,
@ -472,7 +473,8 @@ class IrToTruffle(
new expressionProcessor.BuildFunctionBody(
fn.arguments,
fn.body,
None
None,
true
)
val rootNode = MethodRootNode.build(
language,
@ -763,13 +765,18 @@ class IrToTruffle(
* @param ir the IR to generate code for
* @return a truffle expression that represents the same program as `ir`
*/
def run(ir: IR.Expression): RuntimeExpression = run(ir, false)
def run(ir: IR.Expression): RuntimeExpression = run(ir, false, true)
private def run(ir: IR.Expression, binding: Boolean): RuntimeExpression = {
private def run(
ir: IR.Expression,
binding: Boolean,
subjectToInstrumentation: Boolean
): RuntimeExpression = {
val runtimeExpression = ir match {
case block: IR.Expression.Block => processBlock(block)
case literal: IR.Literal => processLiteral(literal)
case app: IR.Application => processApplication(app)
case block: IR.Expression.Block => processBlock(block)
case literal: IR.Literal => processLiteral(literal)
case app: IR.Application =>
processApplication(app, subjectToInstrumentation)
case name: IR.Name => processName(name)
case function: IR.Function => processFunction(function, binding)
case binding: IR.Expression.Binding => processBinding(binding)
@ -1215,7 +1222,7 @@ class IrToTruffle(
val slotIdx = scope.getVarSlotIdx(occInfo.id)
setLocation(
AssignmentNode.build(this.run(binding.expression, true), slotIdx),
AssignmentNode.build(this.run(binding.expression, true, true), slotIdx),
binding.location
)
}
@ -1464,7 +1471,8 @@ class IrToTruffle(
class BuildFunctionBody(
val arguments: List[IR.DefinitionArgument],
val body: IR.Expression,
val effectContext: Option[String]
val effectContext: Option[String],
val subjectToInstrumentation: Boolean
) {
private val argFactory = new DefinitionArgumentProcessor(scopeName, scope)
private lazy val slots = computeSlots()
@ -1484,7 +1492,8 @@ class IrToTruffle(
arguments.map(_.name.name),
argSlotIdxs
)
case _ => ExpressionProcessor.this.run(body)
case _ =>
ExpressionProcessor.this.run(body, false, subjectToInstrumentation)
}
val block = BlockNode.build(argExpressions.toArray, bodyExpr)
effectContext match {
@ -1570,7 +1579,7 @@ class IrToTruffle(
location: Option[IdentifiedLocation],
binding: Boolean = false
): CreateFunctionNode = {
val bodyBuilder = new BuildFunctionBody(arguments, body, None)
val bodyBuilder = new BuildFunctionBody(arguments, body, None, binding)
val fnRootNode = ClosureRootNode.build(
language,
scope,
@ -1617,12 +1626,15 @@ class IrToTruffle(
* @param application the function application to generate code for
* @return the truffle nodes corresponding to `application`
*/
def processApplication(application: IR.Application): RuntimeExpression =
def processApplication(
application: IR.Application,
subjectToInstrumentation: Boolean
): RuntimeExpression =
application match {
case IR.Application.Prefix(fn, Nil, true, _, _, _) =>
run(fn)
case app: IR.Application.Prefix =>
processApplicationWithArgs(app)
processApplicationWithArgs(app, subjectToInstrumentation)
case IR.Application.Force(expr, location, _, _) =>
setLocation(ForceNode.build(this.run(expr)), location)
case IR.Application.Literal.Sequence(items, location, _, _) =>
@ -1653,7 +1665,8 @@ class IrToTruffle(
}
private def processApplicationWithArgs(
application: IR.Application.Prefix
application: IR.Application.Prefix,
subjectToInstrumentation: Boolean
): RuntimeExpression = {
val IR.Application.Prefix(fn, args, hasDefaultsSuspended, loc, _, _) =
application
@ -1663,7 +1676,8 @@ class IrToTruffle(
val callArgs = new ArrayBuffer[CallArgument]()
for ((unprocessedArg, position) <- arguments.view.zipWithIndex) {
val arg = callArgFactory.run(unprocessedArg, position)
val arg =
callArgFactory.run(unprocessedArg, position, subjectToInstrumentation)
callArgs.append(arg)
}
@ -1714,7 +1728,11 @@ class IrToTruffle(
* @return a truffle construct corresponding to the argument definition
* `arg`
*/
def run(arg: IR.CallArgument, position: Int): CallArgument =
def run(
arg: IR.CallArgument,
position: Int,
subjectToInstrumentation: Boolean
): CallArgument =
arg match {
case IR.CallArgument.Specified(
name,
@ -1765,7 +1783,7 @@ class IrToTruffle(
argumentExpression,
section,
displayName,
true,
subjectToInstrumentation,
false
)
val callTarget = closureRootNode.getCallTarget