Ensure new and wrapper nodes inherit UUID (#6067)

Instrumentation of calls involving warning values never really worked because:
1) newly created nodes didn't set the UUID of their children
2) the instrumentable wrappers always had an empty (i.e. null) UUID and
they never referred `get`/`setId` calls to their delegates

On the surface, everything worked fine. Except when one actually relied on the instrumentation of values with warnings for proper setup. Then no instrumentation (replacement of nodes) was performed due to empty UUID (as required by `hasTag` of `FunctionCallInstrumentationNode`).

Closes #6045. Discovered in #5893.
This commit is contained in:
Hubert Plociniczak 2023-03-27 19:49:20 +02:00 committed by GitHub
parent b977b5ac01
commit 76409b285d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 225 additions and 14 deletions

View File

@ -651,6 +651,7 @@
- [Use SHA-1 for calculating hashes of modules' IR and bindings][5791] - [Use SHA-1 for calculating hashes of modules' IR and bindings][5791]
- [Don't install Python component on Windows][5900] - [Don't install Python component on Windows][5900]
- [Detect potential name conflicts between exported types and FQNs][5966] - [Detect potential name conflicts between exported types and FQNs][5966]
- [Ensure calls involving warnings remain instrumented][6067]
[3227]: https://github.com/enso-org/enso/pull/3227 [3227]: https://github.com/enso-org/enso/pull/3227
[3248]: https://github.com/enso-org/enso/pull/3248 [3248]: https://github.com/enso-org/enso/pull/3248
@ -753,6 +754,7 @@
[5791]: https://github.com/enso-org/enso/pull/5791 [5791]: https://github.com/enso-org/enso/pull/5791
[5900]: https://github.com/enso-org/enso/pull/5900 [5900]: https://github.com/enso-org/enso/pull/5900
[5966]: https://github.com/enso-org/enso/pull/5966 [5966]: https://github.com/enso-org/enso/pull/5966
[6067]: https://github.com/enso-org/enso/pull/6067
# Enso 2.0.0-alpha.18 (2021-10-12) # Enso 2.0.0-alpha.18 (2021-10-12)

View File

@ -230,9 +230,9 @@ public class IdExecutionInstrument extends TruffleInstrument implements IdExecut
Node node = context.getInstrumentedNode(); Node node = context.getInstrumentedNode();
if (node instanceof FunctionCallInstrumentationNode if (node instanceof FunctionCallInstrumentationNode
&& result instanceof FunctionCallInstrumentationNode.FunctionCall) { && result instanceof FunctionCallInstrumentationNode.FunctionCall functionCall) {
UUID nodeId = ((FunctionCallInstrumentationNode) node).getId(); UUID nodeId = ((FunctionCallInstrumentationNode) node).getId();
onFunctionReturn(nodeId, result, context); onFunctionReturn(nodeId, functionCall, context);
} else if (node instanceof ExpressionNode) { } else if (node instanceof ExpressionNode) {
onExpressionReturn(result, node, context); onExpressionReturn(result, node, context);
} }
@ -307,11 +307,10 @@ public class IdExecutionInstrument extends TruffleInstrument implements IdExecut
} }
@CompilerDirectives.TruffleBoundary @CompilerDirectives.TruffleBoundary
private void onFunctionReturn(UUID nodeId, Object result, EventContext context) throws ThreadDeath { private void onFunctionReturn(UUID nodeId, FunctionCallInstrumentationNode.FunctionCall result, EventContext context) throws ThreadDeath {
calls.put( calls.put(
nodeId, new FunctionCallInfo((FunctionCallInstrumentationNode.FunctionCall) result)); nodeId, new FunctionCallInfo(result));
functionCallCallback.accept( functionCallCallback.accept(new ExpressionCall(nodeId, result));
new ExpressionCall(nodeId, (FunctionCallInstrumentationNode.FunctionCall) result));
// Return cached value after capturing the enterable function call in `functionCallCallback` // Return cached value after capturing the enterable function call in `functionCallCallback`
Object cachedResult = cache.get(nodeId); Object cachedResult = cache.get(nodeId);
if (cachedResult != null) { if (cachedResult != null) {

View File

@ -1,14 +1,22 @@
package org.enso.interpreter.test; package org.enso.interpreter.test;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.instrumentation.EventContext; import com.oracle.truffle.api.instrumentation.EventContext;
import com.oracle.truffle.api.instrumentation.ExecutionEventNode; import com.oracle.truffle.api.instrumentation.ExecutionEventNode;
import com.oracle.truffle.api.instrumentation.ExecutionEventNodeFactory; import com.oracle.truffle.api.instrumentation.ExecutionEventNodeFactory;
import com.oracle.truffle.api.instrumentation.SourceSectionFilter; import com.oracle.truffle.api.instrumentation.SourceSectionFilter;
import com.oracle.truffle.api.instrumentation.TruffleInstrument; import com.oracle.truffle.api.instrumentation.TruffleInstrument;
import com.oracle.truffle.api.nodes.Node; import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.api.nodes.RootNode;
import com.oracle.truffle.api.source.SourceSection; import com.oracle.truffle.api.source.SourceSection;
import org.enso.interpreter.node.MethodRootNode;
import org.enso.interpreter.node.callable.FunctionCallInstrumentationNode;
import org.enso.pkg.QualifiedName;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.Function; import java.util.function.Function;
@ -22,6 +30,8 @@ public class NodeCountingTestInstrument extends TruffleInstrument {
public static final String INSTRUMENT_ID = "node-count-test"; public static final String INSTRUMENT_ID = "node-count-test";
private Map<Node, Node> all = new ConcurrentHashMap<>(); private Map<Node, Node> all = new ConcurrentHashMap<>();
private Map<Class, List<Node>> counter = new ConcurrentHashMap<>(); private Map<Class, List<Node>> counter = new ConcurrentHashMap<>();
private Map<UUID, FunctionCallInfo> calls = new ConcurrentHashMap<>();
private Env env; private Env env;
@Override @Override
@ -33,7 +43,17 @@ public class NodeCountingTestInstrument extends TruffleInstrument {
public void enable() { public void enable() {
this.env this.env
.getInstrumenter() .getInstrumenter()
.attachExecutionEventFactory(SourceSectionFilter.ANY, new CountingFactory()); .attachExecutionEventFactory(SourceSectionFilter.ANY, new CountingAndFunctionCallFactory());
}
public void enable(SourceSectionFilter filter) {
this.env
.getInstrumenter()
.attachExecutionEventFactory(filter, new CountingAndFunctionCallFactory());
}
public Map<UUID, FunctionCallInfo> registeredCalls() {
return calls;
} }
public Map<Class, List<Node>> assertNewNodes(String msg, int min, int max) { public Map<Class, List<Node>> assertNewNodes(String msg, int min, int max) {
@ -73,7 +93,7 @@ public class NodeCountingTestInstrument extends TruffleInstrument {
} }
} }
private final class CountingFactory implements ExecutionEventNodeFactory { private final class CountingAndFunctionCallFactory implements ExecutionEventNodeFactory {
@Override @Override
public ExecutionEventNode create(EventContext context) { public ExecutionEventNode create(EventContext context) {
final Node node = context.getInstrumentedNode(); final Node node = context.getInstrumentedNode();
@ -81,8 +101,92 @@ public class NodeCountingTestInstrument extends TruffleInstrument {
if (all.put(node, node) == null) { if (all.put(node, node) == null) {
counter.computeIfAbsent(node.getClass(), (__) -> new CopyOnWriteArrayList<>()).add(node); counter.computeIfAbsent(node.getClass(), (__) -> new CopyOnWriteArrayList<>()).add(node);
} }
return new NodeWrapper(context, calls);
} }
return null; return null;
} }
} }
private class NodeWrapper extends ExecutionEventNode {
private final EventContext context;
private final Map<UUID, FunctionCallInfo> calls;
public NodeWrapper(EventContext context, Map<UUID, FunctionCallInfo> calls) {
this.context = context;
this.calls = calls;
}
public void onReturnValue(VirtualFrame frame, Object result) {
Node node = context.getInstrumentedNode();
if (node instanceof FunctionCallInstrumentationNode instrumentableNode
&& result instanceof FunctionCallInstrumentationNode.FunctionCall functionCall) {
onFunctionReturn(instrumentableNode, functionCall);
}
}
private void onFunctionReturn(FunctionCallInstrumentationNode node, FunctionCallInstrumentationNode.FunctionCall result) {
if (node.getId() != null) {
calls.put(node.getId(), new FunctionCallInfo(result));
}
}
}
public class FunctionCallInfo {
private final QualifiedName moduleName;
private final QualifiedName typeName;
private final String functionName;
public FunctionCallInfo(FunctionCallInstrumentationNode.FunctionCall call) {
RootNode rootNode = call.getFunction().getCallTarget().getRootNode();
if (rootNode instanceof MethodRootNode methodNode) {
moduleName = methodNode.getModuleScope().getModule().getName();
typeName = methodNode.getType().getQualifiedName();
functionName = methodNode.getMethodName();
} else {
moduleName = null;
typeName = null;
functionName = rootNode.getName();
}
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
FunctionCallInfo that = (FunctionCallInfo) o;
return Objects.equals(moduleName, that.moduleName)
&& Objects.equals(typeName, that.typeName)
&& Objects.equals(functionName, that.functionName);
}
@Override
public int hashCode() {
return Objects.hash(moduleName, typeName, functionName);
}
@Override
public String toString() {
return moduleName + "::" + typeName + "::" + functionName;
}
public QualifiedName getModuleName() {
return moduleName;
}
public QualifiedName getTypeName() {
return typeName;
}
public String getFunctionName() {
return functionName;
}
}
} }

View File

@ -65,9 +65,7 @@ public class IncrementalUpdatesTest {
sendUpdatesWhenFunctionBodyIsChangedBySettingValue("4", ConstantsGen.INTEGER, "4", "5", "5", LiteralNode.class); sendUpdatesWhenFunctionBodyIsChangedBySettingValue("4", ConstantsGen.INTEGER, "4", "5", "5", LiteralNode.class);
var m = context.languageContext().findModule(MODULE_NAME).orElse(null); var m = context.languageContext().findModule(MODULE_NAME).orElse(null);
assertNotNull("Module found", m); assertNotNull("Module found", m);
var numbers = m.getIr().preorder().filter((v1) -> { var numbers = m.getIr().preorder().filter((v1) -> v1 instanceof IR$Literal$Number);
return v1 instanceof IR$Literal$Number;
});
assertEquals("One number found: " + numbers, 1, numbers.size()); assertEquals("One number found: " + numbers, 1, numbers.size());
if (numbers.head() instanceof IR$Literal$Number n) { if (numbers.head() instanceof IR$Literal$Number n) {
assertEquals("updated to 5", "5", n.value()); assertEquals("updated to 5", "5", n.value());

View File

@ -0,0 +1,92 @@
package org.enso.interpreter.test.instrument;
import com.oracle.truffle.api.instrumentation.SourceSectionFilter;
import com.oracle.truffle.api.instrumentation.StandardTags;
import org.enso.interpreter.runtime.tag.AvoidIdInstrumentationTag;
import org.enso.interpreter.runtime.tag.IdentifiedTag;
import org.enso.interpreter.test.Metadata;
import org.enso.interpreter.test.NodeCountingTestInstrument;
import org.enso.polyglot.RuntimeOptions;
import org.graalvm.polyglot.Context;
import org.graalvm.polyglot.Language;
import org.graalvm.polyglot.Source;
import static org.junit.Assert.assertEquals;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import java.io.OutputStream;
import java.nio.file.Paths;
import java.util.Map;
public class WarningInstrumentationTest {
private Context context;
private NodeCountingTestInstrument instrument;
@Before
public void initContext() {
context = Context.newBuilder()
.allowExperimentalOptions(true)
.option(
RuntimeOptions.LANGUAGE_HOME_OVERRIDE,
Paths.get("../../distribution/component").toFile().getAbsolutePath()
)
.logHandler(OutputStream.nullOutputStream())
.allowExperimentalOptions(true)
.allowIO(true)
.allowAllAccess(true)
.build();
var engine = context.getEngine();
Map<String, Language> langs = engine.getLanguages();
Assert.assertNotNull("Enso found: " + langs, langs.get("enso"));
instrument = engine.getInstruments().get(NodeCountingTestInstrument.INSTRUMENT_ID).lookup(NodeCountingTestInstrument.class);
SourceSectionFilter builder = SourceSectionFilter.newBuilder()
.tagIs(StandardTags.ExpressionTag.class, StandardTags.CallTag.class)
.tagIs(IdentifiedTag.class)
.tagIsNot(AvoidIdInstrumentationTag.class)
.build();
instrument.enable(builder);
}
@After
public void disposeContext() {
context.close();
}
@Test
public void instrumentValueWithWarnings() throws Exception {
var metadata = new Metadata();
var idOp1 = metadata.addItem(151, 34, null);
var idOp2 = metadata.addItem(202, 31, null);
var idOp3 = metadata.addItem(250, 13, null);
var rawCode = """
from Standard.Base import all
from Standard.Base.Warning import Warning
from Standard.Table.Data.Table import Table
run column_name =
operator1 = Table.new [[column_name, [1,2,3]]]
operator2 = Warning.attach "Text" operator1
operator3 = operator2.get
operator3
""";
var code = metadata.appendToCode(rawCode);
var src = Source.newBuilder("enso", code, "TestWarning.enso").build();
var module = context.eval(src);
var res = module.invokeMember("eval_expression", "run");
res.execute("A");
var calls = instrument.registeredCalls();
assertEquals(calls.keySet().size(), 3);
assertEquals(calls.get(idOp1).getFunctionName(), "new");
assertEquals(calls.get(idOp2).getFunctionName(), "attach");
assertEquals(calls.get(idOp3).getTypeName().item(), "Table");
assertEquals(calls.get(idOp3).getFunctionName(), "get");
}
}

View File

@ -142,13 +142,15 @@ public class FunctionCallInstrumentationNode extends Node implements Instrumenta
*/ */
@Override @Override
public WrapperNode createWrapper(ProbeNode probeNode) { public WrapperNode createWrapper(ProbeNode probeNode) {
return new FunctionCallInstrumentationNodeWrapper(this, probeNode); var wrapper = new FunctionCallInstrumentationNodeWrapper(this, probeNode);
wrapper.setId(this.getId());
return wrapper;
} }
/** /**
* Makrs this node with relevant runtime tags. * Marks this node with relevant runtime tags.
* *
* @param tag the tag to check agains. * @param tag the tag to check against.
* @return true if the node carries the {@code tag}, false otherwise. * @return true if the node carries the {@code tag}, false otherwise.
*/ */
@Override @Override

View File

@ -279,6 +279,7 @@ public abstract class InvokeCallableNode extends BaseNode {
invokeFunctionNode.getDefaultsExecutionMode(), invokeFunctionNode.getDefaultsExecutionMode(),
invokeFunctionNode.getArgumentsExecutionMode())); invokeFunctionNode.getArgumentsExecutionMode()));
childDispatch.setTailStatus(getTailStatus()); childDispatch.setTailStatus(getTailStatus());
childDispatch.setId(invokeFunctionNode.getId());
notifyInserted(childDispatch); notifyInserted(childDispatch);
} }
} finally { } finally {
@ -356,5 +357,8 @@ public abstract class InvokeCallableNode extends BaseNode {
invokeFunctionNode.setId(id); invokeFunctionNode.setId(id);
invokeMethodNode.setId(id); invokeMethodNode.setId(id);
invokeConversionNode.setId(id); invokeConversionNode.setId(id);
if (childDispatch != null) {
childDispatch.setId(id);
}
} }
} }

View File

@ -164,6 +164,7 @@ public abstract class InvokeConversionNode extends BaseNode {
invokeFunctionNode.getArgumentsExecutionMode(), invokeFunctionNode.getArgumentsExecutionMode(),
thatArgumentPosition)); thatArgumentPosition));
childDispatch.setTailStatus(getTailStatus()); childDispatch.setTailStatus(getTailStatus());
childDispatch.setId(invokeFunctionNode.getId());
notifyInserted(childDispatch); notifyInserted(childDispatch);
} }
} finally { } finally {

View File

@ -156,6 +156,7 @@ public abstract class InvokeMethodNode extends BaseNode {
invokeFunctionNode.getArgumentsExecutionMode(), invokeFunctionNode.getArgumentsExecutionMode(),
thisArgumentPosition)); thisArgumentPosition));
childDispatch.setTailStatus(getTailStatus()); childDispatch.setTailStatus(getTailStatus());
childDispatch.setId(invokeFunctionNode.getId());
notifyInserted(childDispatch); notifyInserted(childDispatch);
} }
} finally { } finally {
@ -525,5 +526,8 @@ public abstract class InvokeMethodNode extends BaseNode {
*/ */
public void setId(UUID id) { public void setId(UUID id) {
invokeFunctionNode.setId(id); invokeFunctionNode.setId(id);
if (childDispatch != null) {
childDispatch.setId(id);
}
} }
} }

View File

@ -189,4 +189,9 @@ public abstract class InvokeFunctionNode extends BaseNode {
public void setId(UUID id) { public void setId(UUID id) {
functionCallInstrumentationNode.setId(id); functionCallInstrumentationNode.setId(id);
} }
/** Returns expression ID of this node. */
public UUID getId() {
return functionCallInstrumentationNode.getId();
}
} }