Propagate subjectToInstrumentation flag via ExpressionProcessor (#4090)

`subjectToInstrumentation` needs to be propagated via `ExpressionProcessor`.
This commit is contained in:
Jaroslav Tulach 2023-01-30 06:48:19 +01:00 committed by GitHub
parent 7e8f49e86f
commit 53efcf0a17
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 75 additions and 27 deletions

View File

@ -1,6 +1,7 @@
package org.enso.interpreter.test.instrument;
import com.oracle.truffle.api.instrumentation.InstrumentableNode;
import com.oracle.truffle.api.instrumentation.StandardTags;
import com.oracle.truffle.api.nodes.RootNode;
import com.oracle.truffle.api.source.SourceSection;
import java.io.OutputStream;
import java.nio.file.Paths;
@ -17,6 +18,7 @@ import org.graalvm.polyglot.Source;
import org.junit.After;
import org.junit.Assert;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.fail;
import org.junit.Before;
@ -77,6 +79,33 @@ public class AvoidIdInstrumentationTagTest {
assertAvoidIdInstrumentationTag(isLambda);
}
@Test
public void avoidIdInstrumentationInLambdaMapFunctionYear2010() throws Exception {
var code = """
from Standard.Base import all
operator13 = [ 1973, 1975, 2005, 2006 ]
operator15 = operator13.map year-> if year < 2000 then [255, 100] else if year < 2010 then [0, 255] else [0, 100]
""";
var src = Source.newBuilder("enso", code, "YearLambda.enso").build();
var module = context.eval(src);
var res = module.invokeMember("eval_expression", "operator15");
assertEquals("Array of the requested size computed", 4, res.getArraySize());
for (var i = 0; i < res.getArraySize(); i++) {
var element = res.getArrayElement(i);
assertTrue("Also array", element.hasArrayElements());
assertEquals("Size is 2", 2, element.getArraySize());
}
Predicate<SourceSection> isLambda = (ss) -> {
var sameSrc = ss.getSource().getCharacters().toString().equals(src.getCharacters().toString());
var st = ss.getCharacters().toString();
return sameSrc && st.contains("2010") && !st.contains("map");
};
assertAvoidIdInstrumentationTag(isLambda);
}
private void assertAvoidIdInstrumentationTag(Predicate<SourceSection> isLambda) {
var found = nodes.assertNewNodes("Give me nodes", 0, 10000);
var err = new StringBuilder();
@ -91,21 +120,11 @@ public class AvoidIdInstrumentationTagTest {
if (isLambda.test(ss)) {
err.append("\n").append("code: ").append(ss.getCharacters()).append(" for node ").append(n.getClass().getName());
if (n instanceof InstrumentableNode in) {
final boolean hasAvoidIdInstrumentationTag = in.hasTag(AvoidIdInstrumentationTag.class);
if (!hasAvoidIdInstrumentationTag) {
if (!hasAvoidIdInstrumentationTag(err, in, n.getRootNode())) {
missingTagInLambda = true;
} else {
count++;
}
err.append("\n").append(" AvoidIdInstrumentationTag: ").append(hasAvoidIdInstrumentationTag);
err.append("\n").append(" IdentifiedTag: ").append(in.hasTag(IdentifiedTag.class));
err.append("\n").append(" ExpressionTag: ").append(in.hasTag(StandardTags.ExpressionTag.class));
err.append("\n").append(" RootNode: ").append(n.getRootNode());
if (n.getRootNode() instanceof ClosureRootNode crn) {
err.append("\n").append(" ClosureRootNode.subject to instr: ").append(crn.isSubjectToInstrumentation());
err.append("\n").append(" ClosureRootNode.used in bindings: ").append(crn.isUsedInBinding());
}
}
}
}
@ -115,4 +134,21 @@ public class AvoidIdInstrumentationTagTest {
}
assertNotEquals("Found some nodes", 0, count);
}
private boolean hasAvoidIdInstrumentationTag(StringBuilder err, InstrumentableNode in, RootNode rn) {
var hasAvoidIdInstrumentationTag = in.hasTag(AvoidIdInstrumentationTag.class);
if (!hasAvoidIdInstrumentationTag) {
err.append("\nERROR!");
}
err.append("\n").append(" AvoidIdInstrumentationTag: ").append(hasAvoidIdInstrumentationTag);
err.append("\n").append(" IdentifiedTag: ").append(in.hasTag(IdentifiedTag.class));
err.append("\n").append(" ExpressionTag: ").append(in.hasTag(StandardTags.ExpressionTag.class));
err.append("\n").append(" RootNode: ").append(rn);
if (rn instanceof ClosureRootNode crn) {
err.append("\n").append(" ClosureRootNode.subject to instr: ").append(crn.isSubjectToInstrumentation());
err.append("\n").append(" ClosureRootNode.used in bindings: ").append(crn.isUsedInBinding());
}
return hasAvoidIdInstrumentationTag;
}
}

View File

@ -276,7 +276,7 @@ class IrToTruffle(
dataflowInfo
)
val expressionNode =
expressionProcessor.run(annotation.expression)
expressionProcessor.run(annotation.expression, true)
val closureName = s"<default::$scopeName>"
val closureRootNode = ClosureRootNode.build(
language,
@ -473,7 +473,7 @@ class IrToTruffle(
dataflowInfo
)
val expressionNode =
expressionProcessor.run(annotation.expression)
expressionProcessor.run(annotation.expression, true)
val closureName =
s"<default::${expressionProcessor.scopeName}>"
val closureRootNode = ClosureRootNode.build(
@ -848,9 +848,13 @@ class IrToTruffle(
/** Runs the code generation process on the provided piece of [[IR]].
*
* @param ir the IR to generate code for
* @param subjectToInstrumentation value of subject to instrumentation
* @return a truffle expression that represents the same program as `ir`
*/
def run(ir: IR.Expression): RuntimeExpression = run(ir, false, true)
def run(
ir: IR.Expression,
subjectToInstrumentation: Boolean
): RuntimeExpression = run(ir, false, subjectToInstrumentation)
private def run(
ir: IR.Expression,
@ -865,8 +869,9 @@ class IrToTruffle(
case name: IR.Name => processName(name)
case function: IR.Function => processFunction(function, binding)
case binding: IR.Expression.Binding => processBinding(binding)
case caseExpr: IR.Case => processCase(caseExpr)
case typ: IR.Type => processType(typ)
case caseExpr: IR.Case =>
processCase(caseExpr, subjectToInstrumentation)
case typ: IR.Type => processType(typ)
case _: IR.Empty =>
throw new CompilerError(
"Empty IR nodes should not exist during code generation."
@ -893,7 +898,7 @@ class IrToTruffle(
* @return a truffle expression that represents the same program as `ir`
*/
def runInline(ir: IR.Expression): RuntimeExpression = {
val expression = run(ir)
val expression = run(ir, false)
expression
}
@ -932,8 +937,8 @@ class IrToTruffle(
val callTarget = defaultRootNode.getCallTarget
setLocation(CreateThunkNode.build(callTarget), block.location)
} else {
val statementExprs = block.expressions.map(this.run).toArray
val retExpr = this.run(block.returnValue)
val statementExprs = block.expressions.map(this.run(_, true)).toArray
val retExpr = this.run(block.returnValue, true)
val blockNode = BlockNode.build(statementExprs, retExpr)
setLocation(blockNode, block.location)
@ -965,10 +970,13 @@ class IrToTruffle(
* @param caseExpr the case expression to generate code for
* @return the truffle nodes corresponding to `caseExpr`
*/
def processCase(caseExpr: IR.Case): RuntimeExpression =
def processCase(
caseExpr: IR.Case,
subjectToInstrumentation: Boolean
): RuntimeExpression =
caseExpr match {
case IR.Case.Expr(scrutinee, branches, isNested, location, _, _) =>
val scrutineeNode = this.run(scrutinee)
val scrutineeNode = this.run(scrutinee, subjectToInstrumentation)
val maybeCases = branches.map(processCaseBranch)
val allCasesValid = maybeCases.forall(_.isRight)
@ -1746,13 +1754,16 @@ class IrToTruffle(
): RuntimeExpression =
application match {
case IR.Application.Prefix(fn, Nil, true, _, _, _) =>
run(fn)
run(fn, subjectToInstrumentation)
case app: IR.Application.Prefix =>
processApplicationWithArgs(app, subjectToInstrumentation)
case IR.Application.Force(expr, location, _, _) =>
setLocation(ForceNode.build(this.run(expr)), location)
setLocation(
ForceNode.build(this.run(expr, subjectToInstrumentation)),
location
)
case IR.Application.Literal.Sequence(items, location, _, _) =>
val itemNodes = items.map(run).toArray
val itemNodes = items.map(run(_, subjectToInstrumentation)).toArray
setLocation(SequenceLiteralNode.build(itemNodes), location)
case _: IR.Application.Literal.Typeset =>
setLocation(
@ -1808,7 +1819,7 @@ class IrToTruffle(
createOptimised(moduleScope)(scope)(callArgs.toList)
case _ =>
ApplicationNode.build(
this.run(fn),
this.run(fn, subjectToInstrumentation),
callArgs.toArray,
defaultsExecutionMode
)
@ -1876,7 +1887,8 @@ class IrToTruffle(
scope.createChild(scopeInfo.scope, flattenToParent = true)
}
val argumentExpression =
new ExpressionProcessor(childScope, scopeName).run(value)
new ExpressionProcessor(childScope, scopeName)
.run(value, subjectToInstrumentation)
val result = if (!shouldCreateClosureRootNode) {
argumentExpression
@ -1963,7 +1975,7 @@ class IrToTruffle(
inputArg match {
case arg: IR.DefinitionArgument.Specified =>
val defaultExpression = arg.defaultValue
.map(new ExpressionProcessor(scope, scopeName).run(_))
.map(new ExpressionProcessor(scope, scopeName).run(_, false))
.orNull
// Note [Handling Suspended Defaults]