Propagate subjectToInstrumentation flag via ExpressionProcessor (#4090)

`subjectToInstrumentation` needs to be propagated via `ExpressionProcessor`.
This commit is contained in:
Jaroslav Tulach 2023-01-30 06:48:19 +01:00 committed by GitHub
parent 7e8f49e86f
commit 53efcf0a17
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 75 additions and 27 deletions

View File

@ -1,6 +1,7 @@
package org.enso.interpreter.test.instrument; package org.enso.interpreter.test.instrument;
import com.oracle.truffle.api.instrumentation.InstrumentableNode; import com.oracle.truffle.api.instrumentation.InstrumentableNode;
import com.oracle.truffle.api.instrumentation.StandardTags; import com.oracle.truffle.api.instrumentation.StandardTags;
import com.oracle.truffle.api.nodes.RootNode;
import com.oracle.truffle.api.source.SourceSection; import com.oracle.truffle.api.source.SourceSection;
import java.io.OutputStream; import java.io.OutputStream;
import java.nio.file.Paths; import java.nio.file.Paths;
@ -17,6 +18,7 @@ import org.graalvm.polyglot.Source;
import org.junit.After; import org.junit.After;
import org.junit.Assert; import org.junit.Assert;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import org.junit.Before; import org.junit.Before;
@ -77,6 +79,33 @@ public class AvoidIdInstrumentationTagTest {
assertAvoidIdInstrumentationTag(isLambda); assertAvoidIdInstrumentationTag(isLambda);
} }
@Test
public void avoidIdInstrumentationInLambdaMapFunctionYear2010() throws Exception {
var code = """
from Standard.Base import all
operator13 = [ 1973, 1975, 2005, 2006 ]
operator15 = operator13.map year-> if year < 2000 then [255, 100] else if year < 2010 then [0, 255] else [0, 100]
""";
var src = Source.newBuilder("enso", code, "YearLambda.enso").build();
var module = context.eval(src);
var res = module.invokeMember("eval_expression", "operator15");
assertEquals("Array of the requested size computed", 4, res.getArraySize());
for (var i = 0; i < res.getArraySize(); i++) {
var element = res.getArrayElement(i);
assertTrue("Also array", element.hasArrayElements());
assertEquals("Size is 2", 2, element.getArraySize());
}
Predicate<SourceSection> isLambda = (ss) -> {
var sameSrc = ss.getSource().getCharacters().toString().equals(src.getCharacters().toString());
var st = ss.getCharacters().toString();
return sameSrc && st.contains("2010") && !st.contains("map");
};
assertAvoidIdInstrumentationTag(isLambda);
}
private void assertAvoidIdInstrumentationTag(Predicate<SourceSection> isLambda) { private void assertAvoidIdInstrumentationTag(Predicate<SourceSection> isLambda) {
var found = nodes.assertNewNodes("Give me nodes", 0, 10000); var found = nodes.assertNewNodes("Give me nodes", 0, 10000);
var err = new StringBuilder(); var err = new StringBuilder();
@ -91,21 +120,11 @@ public class AvoidIdInstrumentationTagTest {
if (isLambda.test(ss)) { if (isLambda.test(ss)) {
err.append("\n").append("code: ").append(ss.getCharacters()).append(" for node ").append(n.getClass().getName()); err.append("\n").append("code: ").append(ss.getCharacters()).append(" for node ").append(n.getClass().getName());
if (n instanceof InstrumentableNode in) { if (n instanceof InstrumentableNode in) {
final boolean hasAvoidIdInstrumentationTag = in.hasTag(AvoidIdInstrumentationTag.class); if (!hasAvoidIdInstrumentationTag(err, in, n.getRootNode())) {
if (!hasAvoidIdInstrumentationTag) {
missingTagInLambda = true; missingTagInLambda = true;
} else { } else {
count++; count++;
} }
err.append("\n").append(" AvoidIdInstrumentationTag: ").append(hasAvoidIdInstrumentationTag);
err.append("\n").append(" IdentifiedTag: ").append(in.hasTag(IdentifiedTag.class));
err.append("\n").append(" ExpressionTag: ").append(in.hasTag(StandardTags.ExpressionTag.class));
err.append("\n").append(" RootNode: ").append(n.getRootNode());
if (n.getRootNode() instanceof ClosureRootNode crn) {
err.append("\n").append(" ClosureRootNode.subject to instr: ").append(crn.isSubjectToInstrumentation());
err.append("\n").append(" ClosureRootNode.used in bindings: ").append(crn.isUsedInBinding());
}
} }
} }
} }
@ -115,4 +134,21 @@ public class AvoidIdInstrumentationTagTest {
} }
assertNotEquals("Found some nodes", 0, count); assertNotEquals("Found some nodes", 0, count);
} }
private boolean hasAvoidIdInstrumentationTag(StringBuilder err, InstrumentableNode in, RootNode rn) {
var hasAvoidIdInstrumentationTag = in.hasTag(AvoidIdInstrumentationTag.class);
if (!hasAvoidIdInstrumentationTag) {
err.append("\nERROR!");
}
err.append("\n").append(" AvoidIdInstrumentationTag: ").append(hasAvoidIdInstrumentationTag);
err.append("\n").append(" IdentifiedTag: ").append(in.hasTag(IdentifiedTag.class));
err.append("\n").append(" ExpressionTag: ").append(in.hasTag(StandardTags.ExpressionTag.class));
err.append("\n").append(" RootNode: ").append(rn);
if (rn instanceof ClosureRootNode crn) {
err.append("\n").append(" ClosureRootNode.subject to instr: ").append(crn.isSubjectToInstrumentation());
err.append("\n").append(" ClosureRootNode.used in bindings: ").append(crn.isUsedInBinding());
}
return hasAvoidIdInstrumentationTag;
}
} }

View File

@ -276,7 +276,7 @@ class IrToTruffle(
dataflowInfo dataflowInfo
) )
val expressionNode = val expressionNode =
expressionProcessor.run(annotation.expression) expressionProcessor.run(annotation.expression, true)
val closureName = s"<default::$scopeName>" val closureName = s"<default::$scopeName>"
val closureRootNode = ClosureRootNode.build( val closureRootNode = ClosureRootNode.build(
language, language,
@ -473,7 +473,7 @@ class IrToTruffle(
dataflowInfo dataflowInfo
) )
val expressionNode = val expressionNode =
expressionProcessor.run(annotation.expression) expressionProcessor.run(annotation.expression, true)
val closureName = val closureName =
s"<default::${expressionProcessor.scopeName}>" s"<default::${expressionProcessor.scopeName}>"
val closureRootNode = ClosureRootNode.build( val closureRootNode = ClosureRootNode.build(
@ -848,9 +848,13 @@ class IrToTruffle(
/** Runs the code generation process on the provided piece of [[IR]]. /** Runs the code generation process on the provided piece of [[IR]].
* *
* @param ir the IR to generate code for * @param ir the IR to generate code for
* @param subjectToInstrumentation value of subject to instrumentation
* @return a truffle expression that represents the same program as `ir` * @return a truffle expression that represents the same program as `ir`
*/ */
def run(ir: IR.Expression): RuntimeExpression = run(ir, false, true) def run(
ir: IR.Expression,
subjectToInstrumentation: Boolean
): RuntimeExpression = run(ir, false, subjectToInstrumentation)
private def run( private def run(
ir: IR.Expression, ir: IR.Expression,
@ -865,8 +869,9 @@ class IrToTruffle(
case name: IR.Name => processName(name) case name: IR.Name => processName(name)
case function: IR.Function => processFunction(function, binding) case function: IR.Function => processFunction(function, binding)
case binding: IR.Expression.Binding => processBinding(binding) case binding: IR.Expression.Binding => processBinding(binding)
case caseExpr: IR.Case => processCase(caseExpr) case caseExpr: IR.Case =>
case typ: IR.Type => processType(typ) processCase(caseExpr, subjectToInstrumentation)
case typ: IR.Type => processType(typ)
case _: IR.Empty => case _: IR.Empty =>
throw new CompilerError( throw new CompilerError(
"Empty IR nodes should not exist during code generation." "Empty IR nodes should not exist during code generation."
@ -893,7 +898,7 @@ class IrToTruffle(
* @return a truffle expression that represents the same program as `ir` * @return a truffle expression that represents the same program as `ir`
*/ */
def runInline(ir: IR.Expression): RuntimeExpression = { def runInline(ir: IR.Expression): RuntimeExpression = {
val expression = run(ir) val expression = run(ir, false)
expression expression
} }
@ -932,8 +937,8 @@ class IrToTruffle(
val callTarget = defaultRootNode.getCallTarget val callTarget = defaultRootNode.getCallTarget
setLocation(CreateThunkNode.build(callTarget), block.location) setLocation(CreateThunkNode.build(callTarget), block.location)
} else { } else {
val statementExprs = block.expressions.map(this.run).toArray val statementExprs = block.expressions.map(this.run(_, true)).toArray
val retExpr = this.run(block.returnValue) val retExpr = this.run(block.returnValue, true)
val blockNode = BlockNode.build(statementExprs, retExpr) val blockNode = BlockNode.build(statementExprs, retExpr)
setLocation(blockNode, block.location) setLocation(blockNode, block.location)
@ -965,10 +970,13 @@ class IrToTruffle(
* @param caseExpr the case expression to generate code for * @param caseExpr the case expression to generate code for
* @return the truffle nodes corresponding to `caseExpr` * @return the truffle nodes corresponding to `caseExpr`
*/ */
def processCase(caseExpr: IR.Case): RuntimeExpression = def processCase(
caseExpr: IR.Case,
subjectToInstrumentation: Boolean
): RuntimeExpression =
caseExpr match { caseExpr match {
case IR.Case.Expr(scrutinee, branches, isNested, location, _, _) => case IR.Case.Expr(scrutinee, branches, isNested, location, _, _) =>
val scrutineeNode = this.run(scrutinee) val scrutineeNode = this.run(scrutinee, subjectToInstrumentation)
val maybeCases = branches.map(processCaseBranch) val maybeCases = branches.map(processCaseBranch)
val allCasesValid = maybeCases.forall(_.isRight) val allCasesValid = maybeCases.forall(_.isRight)
@ -1746,13 +1754,16 @@ class IrToTruffle(
): RuntimeExpression = ): RuntimeExpression =
application match { application match {
case IR.Application.Prefix(fn, Nil, true, _, _, _) => case IR.Application.Prefix(fn, Nil, true, _, _, _) =>
run(fn) run(fn, subjectToInstrumentation)
case app: IR.Application.Prefix => case app: IR.Application.Prefix =>
processApplicationWithArgs(app, subjectToInstrumentation) processApplicationWithArgs(app, subjectToInstrumentation)
case IR.Application.Force(expr, location, _, _) => case IR.Application.Force(expr, location, _, _) =>
setLocation(ForceNode.build(this.run(expr)), location) setLocation(
ForceNode.build(this.run(expr, subjectToInstrumentation)),
location
)
case IR.Application.Literal.Sequence(items, location, _, _) => case IR.Application.Literal.Sequence(items, location, _, _) =>
val itemNodes = items.map(run).toArray val itemNodes = items.map(run(_, subjectToInstrumentation)).toArray
setLocation(SequenceLiteralNode.build(itemNodes), location) setLocation(SequenceLiteralNode.build(itemNodes), location)
case _: IR.Application.Literal.Typeset => case _: IR.Application.Literal.Typeset =>
setLocation( setLocation(
@ -1808,7 +1819,7 @@ class IrToTruffle(
createOptimised(moduleScope)(scope)(callArgs.toList) createOptimised(moduleScope)(scope)(callArgs.toList)
case _ => case _ =>
ApplicationNode.build( ApplicationNode.build(
this.run(fn), this.run(fn, subjectToInstrumentation),
callArgs.toArray, callArgs.toArray,
defaultsExecutionMode defaultsExecutionMode
) )
@ -1876,7 +1887,8 @@ class IrToTruffle(
scope.createChild(scopeInfo.scope, flattenToParent = true) scope.createChild(scopeInfo.scope, flattenToParent = true)
} }
val argumentExpression = val argumentExpression =
new ExpressionProcessor(childScope, scopeName).run(value) new ExpressionProcessor(childScope, scopeName)
.run(value, subjectToInstrumentation)
val result = if (!shouldCreateClosureRootNode) { val result = if (!shouldCreateClosureRootNode) {
argumentExpression argumentExpression
@ -1963,7 +1975,7 @@ class IrToTruffle(
inputArg match { inputArg match {
case arg: IR.DefinitionArgument.Specified => case arg: IR.DefinitionArgument.Specified =>
val defaultExpression = arg.defaultValue val defaultExpression = arg.defaultValue
.map(new ExpressionProcessor(scope, scopeName).run(_)) .map(new ExpressionProcessor(scope, scopeName).run(_, false))
.orNull .orNull
// Note [Handling Suspended Defaults] // Note [Handling Suspended Defaults]