From 53efcf0a17f8bbc29e573d67a515e211a7bba2da Mon Sep 17 00:00:00 2001 From: Jaroslav Tulach Date: Mon, 30 Jan 2023 06:48:19 +0100 Subject: [PATCH] Propagate subjectToInstrumentation flag via ExpressionProcessor (#4090) `subjectToInstrumentation` needs to be propagated via `ExpressionProcessor`. --- .../AvoidIdInstrumentationTagTest.java | 58 +++++++++++++++---- .../enso/compiler/codegen/IrToTruffle.scala | 44 +++++++++----- 2 files changed, 75 insertions(+), 27 deletions(-) diff --git a/engine/runtime-with-instruments/src/test/java/org/enso/interpreter/test/instrument/AvoidIdInstrumentationTagTest.java b/engine/runtime-with-instruments/src/test/java/org/enso/interpreter/test/instrument/AvoidIdInstrumentationTagTest.java index b9f1b8773a..5852493006 100644 --- a/engine/runtime-with-instruments/src/test/java/org/enso/interpreter/test/instrument/AvoidIdInstrumentationTagTest.java +++ b/engine/runtime-with-instruments/src/test/java/org/enso/interpreter/test/instrument/AvoidIdInstrumentationTagTest.java @@ -1,6 +1,7 @@ package org.enso.interpreter.test.instrument; import com.oracle.truffle.api.instrumentation.InstrumentableNode; import com.oracle.truffle.api.instrumentation.StandardTags; +import com.oracle.truffle.api.nodes.RootNode; import com.oracle.truffle.api.source.SourceSection; import java.io.OutputStream; import java.nio.file.Paths; @@ -17,6 +18,7 @@ import org.graalvm.polyglot.Source; import org.junit.After; import org.junit.Assert; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.fail; import org.junit.Before; @@ -77,6 +79,33 @@ public class AvoidIdInstrumentationTagTest { assertAvoidIdInstrumentationTag(isLambda); } + @Test + public void avoidIdInstrumentationInLambdaMapFunctionYear2010() throws Exception { + var code = """ + from Standard.Base import all + + operator13 = [ 1973, 1975, 2005, 2006 ] + operator15 = operator13.map year-> if year < 2000 then [255, 100] else if year < 2010 then [0, 255] else [0, 100] + """; + var src = Source.newBuilder("enso", code, "YearLambda.enso").build(); + var module = context.eval(src); + var res = module.invokeMember("eval_expression", "operator15"); + assertEquals("Array of the requested size computed", 4, res.getArraySize()); + for (var i = 0; i < res.getArraySize(); i++) { + var element = res.getArrayElement(i); + assertTrue("Also array", element.hasArrayElements()); + assertEquals("Size is 2", 2, element.getArraySize()); + } + + Predicate isLambda = (ss) -> { + var sameSrc = ss.getSource().getCharacters().toString().equals(src.getCharacters().toString()); + var st = ss.getCharacters().toString(); + return sameSrc && st.contains("2010") && !st.contains("map"); + }; + + assertAvoidIdInstrumentationTag(isLambda); + } + private void assertAvoidIdInstrumentationTag(Predicate isLambda) { var found = nodes.assertNewNodes("Give me nodes", 0, 10000); var err = new StringBuilder(); @@ -91,21 +120,11 @@ public class AvoidIdInstrumentationTagTest { if (isLambda.test(ss)) { err.append("\n").append("code: ").append(ss.getCharacters()).append(" for node ").append(n.getClass().getName()); if (n instanceof InstrumentableNode in) { - final boolean hasAvoidIdInstrumentationTag = in.hasTag(AvoidIdInstrumentationTag.class); - if (!hasAvoidIdInstrumentationTag) { + if (!hasAvoidIdInstrumentationTag(err, in, n.getRootNode())) { missingTagInLambda = true; } else { count++; } - - err.append("\n").append(" AvoidIdInstrumentationTag: ").append(hasAvoidIdInstrumentationTag); - err.append("\n").append(" IdentifiedTag: ").append(in.hasTag(IdentifiedTag.class)); - err.append("\n").append(" ExpressionTag: ").append(in.hasTag(StandardTags.ExpressionTag.class)); - err.append("\n").append(" RootNode: ").append(n.getRootNode()); - if (n.getRootNode() instanceof ClosureRootNode crn) { - err.append("\n").append(" ClosureRootNode.subject to instr: ").append(crn.isSubjectToInstrumentation()); - err.append("\n").append(" ClosureRootNode.used in bindings: ").append(crn.isUsedInBinding()); - } } } } @@ -115,4 +134,21 @@ public class AvoidIdInstrumentationTagTest { } assertNotEquals("Found some nodes", 0, count); } + + private boolean hasAvoidIdInstrumentationTag(StringBuilder err, InstrumentableNode in, RootNode rn) { + var hasAvoidIdInstrumentationTag = in.hasTag(AvoidIdInstrumentationTag.class); + if (!hasAvoidIdInstrumentationTag) { + err.append("\nERROR!"); + } + + err.append("\n").append(" AvoidIdInstrumentationTag: ").append(hasAvoidIdInstrumentationTag); + err.append("\n").append(" IdentifiedTag: ").append(in.hasTag(IdentifiedTag.class)); + err.append("\n").append(" ExpressionTag: ").append(in.hasTag(StandardTags.ExpressionTag.class)); + err.append("\n").append(" RootNode: ").append(rn); + if (rn instanceof ClosureRootNode crn) { + err.append("\n").append(" ClosureRootNode.subject to instr: ").append(crn.isSubjectToInstrumentation()); + err.append("\n").append(" ClosureRootNode.used in bindings: ").append(crn.isUsedInBinding()); + } + return hasAvoidIdInstrumentationTag; + } } diff --git a/engine/runtime/src/main/scala/org/enso/compiler/codegen/IrToTruffle.scala b/engine/runtime/src/main/scala/org/enso/compiler/codegen/IrToTruffle.scala index 421c503d26..24c526ed80 100644 --- a/engine/runtime/src/main/scala/org/enso/compiler/codegen/IrToTruffle.scala +++ b/engine/runtime/src/main/scala/org/enso/compiler/codegen/IrToTruffle.scala @@ -276,7 +276,7 @@ class IrToTruffle( dataflowInfo ) val expressionNode = - expressionProcessor.run(annotation.expression) + expressionProcessor.run(annotation.expression, true) val closureName = s"" val closureRootNode = ClosureRootNode.build( language, @@ -473,7 +473,7 @@ class IrToTruffle( dataflowInfo ) val expressionNode = - expressionProcessor.run(annotation.expression) + expressionProcessor.run(annotation.expression, true) val closureName = s"" val closureRootNode = ClosureRootNode.build( @@ -848,9 +848,13 @@ class IrToTruffle( /** Runs the code generation process on the provided piece of [[IR]]. * * @param ir the IR to generate code for + * @param subjectToInstrumentation value of subject to instrumentation * @return a truffle expression that represents the same program as `ir` */ - def run(ir: IR.Expression): RuntimeExpression = run(ir, false, true) + def run( + ir: IR.Expression, + subjectToInstrumentation: Boolean + ): RuntimeExpression = run(ir, false, subjectToInstrumentation) private def run( ir: IR.Expression, @@ -865,8 +869,9 @@ class IrToTruffle( case name: IR.Name => processName(name) case function: IR.Function => processFunction(function, binding) case binding: IR.Expression.Binding => processBinding(binding) - case caseExpr: IR.Case => processCase(caseExpr) - case typ: IR.Type => processType(typ) + case caseExpr: IR.Case => + processCase(caseExpr, subjectToInstrumentation) + case typ: IR.Type => processType(typ) case _: IR.Empty => throw new CompilerError( "Empty IR nodes should not exist during code generation." @@ -893,7 +898,7 @@ class IrToTruffle( * @return a truffle expression that represents the same program as `ir` */ def runInline(ir: IR.Expression): RuntimeExpression = { - val expression = run(ir) + val expression = run(ir, false) expression } @@ -932,8 +937,8 @@ class IrToTruffle( val callTarget = defaultRootNode.getCallTarget setLocation(CreateThunkNode.build(callTarget), block.location) } else { - val statementExprs = block.expressions.map(this.run).toArray - val retExpr = this.run(block.returnValue) + val statementExprs = block.expressions.map(this.run(_, true)).toArray + val retExpr = this.run(block.returnValue, true) val blockNode = BlockNode.build(statementExprs, retExpr) setLocation(blockNode, block.location) @@ -965,10 +970,13 @@ class IrToTruffle( * @param caseExpr the case expression to generate code for * @return the truffle nodes corresponding to `caseExpr` */ - def processCase(caseExpr: IR.Case): RuntimeExpression = + def processCase( + caseExpr: IR.Case, + subjectToInstrumentation: Boolean + ): RuntimeExpression = caseExpr match { case IR.Case.Expr(scrutinee, branches, isNested, location, _, _) => - val scrutineeNode = this.run(scrutinee) + val scrutineeNode = this.run(scrutinee, subjectToInstrumentation) val maybeCases = branches.map(processCaseBranch) val allCasesValid = maybeCases.forall(_.isRight) @@ -1746,13 +1754,16 @@ class IrToTruffle( ): RuntimeExpression = application match { case IR.Application.Prefix(fn, Nil, true, _, _, _) => - run(fn) + run(fn, subjectToInstrumentation) case app: IR.Application.Prefix => processApplicationWithArgs(app, subjectToInstrumentation) case IR.Application.Force(expr, location, _, _) => - setLocation(ForceNode.build(this.run(expr)), location) + setLocation( + ForceNode.build(this.run(expr, subjectToInstrumentation)), + location + ) case IR.Application.Literal.Sequence(items, location, _, _) => - val itemNodes = items.map(run).toArray + val itemNodes = items.map(run(_, subjectToInstrumentation)).toArray setLocation(SequenceLiteralNode.build(itemNodes), location) case _: IR.Application.Literal.Typeset => setLocation( @@ -1808,7 +1819,7 @@ class IrToTruffle( createOptimised(moduleScope)(scope)(callArgs.toList) case _ => ApplicationNode.build( - this.run(fn), + this.run(fn, subjectToInstrumentation), callArgs.toArray, defaultsExecutionMode ) @@ -1876,7 +1887,8 @@ class IrToTruffle( scope.createChild(scopeInfo.scope, flattenToParent = true) } val argumentExpression = - new ExpressionProcessor(childScope, scopeName).run(value) + new ExpressionProcessor(childScope, scopeName) + .run(value, subjectToInstrumentation) val result = if (!shouldCreateClosureRootNode) { argumentExpression @@ -1963,7 +1975,7 @@ class IrToTruffle( inputArg match { case arg: IR.DefinitionArgument.Specified => val defaultExpression = arg.defaultValue - .map(new ExpressionProcessor(scope, scopeName).run(_)) + .map(new ExpressionProcessor(scope, scopeName).run(_, false)) .orNull // Note [Handling Suspended Defaults]