Multi value Complex test and robustness refactoring (#11525)

While working on #11482 and enhancing the tests suite with more tests based on `type Complex` a [getRootNode() did not terminate in 100000 iterations](https://github.com/enso-org/enso/pull/11525#issuecomment-2476171597) problem was discovered. Detailed investigation revealed that the existing `ReadArgumentCheckNode` infrastructure was able to create a **cycle** of parent pointers in the Truffle AST.

The problem was in intricate manipulation of the AST while rewriting internals in `ReadArgumentCheckNode`. This PR avoids such manipulation by _refactoring the type checking code_. `ReadArgumentNode` knows nothing about types anymore. When a type check is needed, `IrToTruffle` adds additional `TypeCheckValueNode.wrap` around the `ReadArgumentNode` - that breaks the **vicious circle**.

All the _type checks_ nodes are moved to its own package. All but one of the classes are made package private. The external API for doing _type checking_ is concentrated into `TypeCheckValueNode`.
This commit is contained in:
Jaroslav Tulach 2024-11-19 18:04:42 +01:00 committed by somebody1234
parent 10e1b76f57
commit 73b93f5e6b
21 changed files with 862 additions and 663 deletions

View File

@ -1,578 +0,0 @@
package org.enso.interpreter.node.callable.argument;
import com.oracle.truffle.api.CompilerAsserts;
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.TruffleLanguage;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Cached.Shared;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.frame.MaterializedFrame;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.interop.InteropLibrary;
import com.oracle.truffle.api.interop.UnsupportedMessageException;
import com.oracle.truffle.api.nodes.ExplodeLoop;
import com.oracle.truffle.api.nodes.InvalidAssumptionException;
import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.api.nodes.NodeUtil;
import com.oracle.truffle.api.nodes.RootNode;
import java.util.Arrays;
import java.util.List;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.enso.interpreter.EnsoLanguage;
import org.enso.interpreter.node.BaseNode.TailStatus;
import org.enso.interpreter.node.EnsoRootNode;
import org.enso.interpreter.node.ExpressionNode;
import org.enso.interpreter.node.callable.ApplicationNode;
import org.enso.interpreter.node.callable.InvokeCallableNode.DefaultsExecutionMode;
import org.enso.interpreter.node.callable.thunk.ThunkExecutorNode;
import org.enso.interpreter.node.expression.builtin.meta.AtomWithAHoleNode;
import org.enso.interpreter.node.expression.builtin.meta.IsValueOfTypeNode;
import org.enso.interpreter.node.expression.literal.LiteralNode;
import org.enso.interpreter.runtime.EnsoContext;
import org.enso.interpreter.runtime.callable.UnresolvedConstructor;
import org.enso.interpreter.runtime.callable.UnresolvedConversion;
import org.enso.interpreter.runtime.callable.argument.ArgumentDefinition;
import org.enso.interpreter.runtime.callable.argument.ArgumentDefinition.ExecutionMode;
import org.enso.interpreter.runtime.callable.argument.CallArgument;
import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.runtime.callable.function.FunctionSchema;
import org.enso.interpreter.runtime.data.EnsoMultiValue;
import org.enso.interpreter.runtime.data.Type;
import org.enso.interpreter.runtime.data.text.Text;
import org.enso.interpreter.runtime.error.DataflowError;
import org.enso.interpreter.runtime.error.PanicException;
import org.enso.interpreter.runtime.error.PanicSentinel;
import org.enso.interpreter.runtime.library.dispatch.TypeOfNode;
import org.enso.interpreter.runtime.library.dispatch.TypesLibrary;
import org.enso.interpreter.runtime.util.CachingSupplier;
import org.graalvm.collections.Pair;
public abstract class ReadArgumentCheckNode extends Node {
private final String comment;
@CompilerDirectives.CompilationFinal private String expectedTypeMessage;
ReadArgumentCheckNode(String comment) {
this.comment = comment;
}
/** */
public static ExpressionNode wrap(ExpressionNode original, ReadArgumentCheckNode check) {
return new TypeCheckExpressionNode(original, check);
}
/**
* Executes check or conversion of the value.
*
* @param frame frame requesting the conversion
* @param value the value to convert
* @return {@code null} when the check isn't satisfied and conversion isn't possible or non-{@code
* null} value that can be used as a result
*/
public final Object handleCheckOrConversion(VirtualFrame frame, Object value) {
var result = executeCheckOrConversion(frame, value);
if (result == null) {
throw panicAtTheEnd(value);
}
return result;
}
abstract Object findDirectMatch(VirtualFrame frame, Object value);
abstract Object executeCheckOrConversion(VirtualFrame frame, Object value);
abstract String expectedTypeMessage();
protected final String joinTypeParts(List<String> parts, String separator) {
assert !parts.isEmpty();
if (parts.size() == 1) {
return parts.get(0);
}
var separatorWithSpace = " " + separator + " ";
var builder = new StringBuilder();
boolean isFirst = true;
for (String part : parts) {
if (isFirst) {
isFirst = false;
} else {
builder.append(separatorWithSpace);
}
// If the part contains a space, it means it is not a single type but already a more complex
// expression with a separator.
// So to ensure we don't mess up the expression layers, we need to add parentheses around it.
boolean needsParentheses = part.contains(" ");
if (needsParentheses) {
builder.append("(");
}
builder.append(part);
if (needsParentheses) {
builder.append(")");
}
}
return builder.toString();
}
final PanicException panicAtTheEnd(Object v) {
if (expectedTypeMessage == null) {
CompilerDirectives.transferToInterpreterAndInvalidate();
expectedTypeMessage = expectedTypeMessage();
}
var ctx = EnsoContext.get(this);
Text msg;
if (v instanceof UnresolvedConstructor) {
msg = Text.create("Cannot find constructor {got} among {exp}");
} else {
var where = Text.create(comment == null ? "expression" : comment);
var exp = Text.create("expected ");
var got = Text.create(" to be {exp}, but got {got}");
msg = Text.create(exp, Text.create(where, got));
}
var err = ctx.getBuiltins().error().makeTypeErrorOfComment(expectedTypeMessage, v, msg);
throw new PanicException(err, this);
}
public static ReadArgumentCheckNode allOf(String argumentName, ReadArgumentCheckNode... checks) {
var list = Arrays.asList(checks);
var flatten =
list.stream()
.flatMap(
n -> n instanceof AllOfNode all ? Arrays.asList(all.checks).stream() : Stream.of(n))
.toList();
var arr = toArray(flatten);
return switch (arr.length) {
case 0 -> null;
case 1 -> arr[0];
default -> new AllOfNode(argumentName, arr);
};
}
public static ReadArgumentCheckNode oneOf(String comment, List<ReadArgumentCheckNode> checks) {
var arr = toArray(checks);
return switch (arr.length) {
case 0 -> null;
case 1 -> arr[0];
default -> new OneOfNode(comment, arr);
};
}
public static ReadArgumentCheckNode build(EnsoContext ctx, String comment, Type expectedType) {
assert ctx.getBuiltins().any() != expectedType : "Don't check for Any: " + expectedType;
return ReadArgumentCheckNodeFactory.TypeCheckNodeGen.create(comment, expectedType);
}
public static ReadArgumentCheckNode meta(
String comment, Supplier<? extends Object> metaObjectSupplier) {
var cachingSupplier = CachingSupplier.wrap(metaObjectSupplier);
return ReadArgumentCheckNodeFactory.MetaCheckNodeGen.create(comment, cachingSupplier);
}
public static boolean isWrappedThunk(Function fn) {
if (fn.getSchema() == LazyCheckRootNode.SCHEMA) {
return fn.getPreAppliedArguments()[0] instanceof Function wrappedFn && wrappedFn.isThunk();
}
return false;
}
private static boolean isAllFitValue(Object v) {
return v instanceof DataflowError || AtomWithAHoleNode.isHole(v);
}
private static ReadArgumentCheckNode[] toArray(List<ReadArgumentCheckNode> list) {
if (list == null) {
return new ReadArgumentCheckNode[0];
}
var cnt = (int) list.stream().filter(n -> n != null).count();
var arr = new ReadArgumentCheckNode[cnt];
var it = list.iterator();
for (int i = 0; i < cnt; ) {
var element = it.next();
if (element != null) {
arr[i++] = element;
}
}
return arr;
}
static final class AllOfNode extends ReadArgumentCheckNode {
@Children private ReadArgumentCheckNode[] checks;
@Child private TypesLibrary types;
AllOfNode(String name, ReadArgumentCheckNode[] checks) {
super(name);
this.checks = checks;
this.types = TypesLibrary.getFactory().createDispatched(checks.length);
}
@Override
Object findDirectMatch(VirtualFrame frame, Object value) {
return null;
}
@Override
@ExplodeLoop
Object executeCheckOrConversion(VirtualFrame frame, Object value) {
var values = new Object[checks.length];
var valueTypes = new Type[checks.length];
var at = 0;
for (var n : checks) {
var result = n.executeCheckOrConversion(frame, value);
if (result == null) {
return null;
}
values[at] = result;
valueTypes[at] = types.getType(result);
at++;
}
return EnsoMultiValue.create(valueTypes, values);
}
@Override
String expectedTypeMessage() {
var parts =
Arrays.stream(checks)
.map(ReadArgumentCheckNode::expectedTypeMessage)
.collect(Collectors.toList());
return joinTypeParts(parts, "&");
}
}
static final class OneOfNode extends ReadArgumentCheckNode {
@Children private ReadArgumentCheckNode[] checks;
OneOfNode(String name, ReadArgumentCheckNode[] checks) {
super(name);
this.checks = checks;
}
@Override
@ExplodeLoop
final Object findDirectMatch(VirtualFrame frame, Object value) {
for (var n : checks) {
var result = n.findDirectMatch(frame, value);
if (result != null) {
return result;
}
}
return null;
}
@Override
@ExplodeLoop
Object executeCheckOrConversion(VirtualFrame frame, Object value) {
var direct = findDirectMatch(frame, value);
if (direct != null) {
return direct;
}
for (var n : checks) {
var result = n.executeCheckOrConversion(frame, value);
if (result != null) {
return result;
}
}
return null;
}
@Override
String expectedTypeMessage() {
var parts =
Arrays.stream(checks)
.map(ReadArgumentCheckNode::expectedTypeMessage)
.collect(Collectors.toList());
return joinTypeParts(parts, "|");
}
}
abstract static class TypeCheckNode extends ReadArgumentCheckNode {
private final Type expectedType;
@Child IsValueOfTypeNode checkType;
@CompilerDirectives.CompilationFinal private String expectedTypeMessage;
@CompilerDirectives.CompilationFinal private LazyCheckRootNode lazyCheck;
@Child private EnsoMultiValue.CastToNode castTo;
TypeCheckNode(String name, Type expectedType) {
super(name);
this.checkType = IsValueOfTypeNode.build();
this.expectedType = expectedType;
}
@Specialization
Object doPanicSentinel(VirtualFrame frame, PanicSentinel panicSentinel) {
throw panicSentinel;
}
@Specialization
Object doUnresolvedConstructor(
VirtualFrame frame,
UnresolvedConstructor unresolved,
@Cached UnresolvedConstructor.ConstructNode construct) {
var state = Function.ArgumentsHelper.getState(frame.getArguments());
return construct.execute(frame, state, expectedType, unresolved);
}
@Specialization(rewriteOn = InvalidAssumptionException.class)
Object doCheckNoConversionNeeded(VirtualFrame frame, Object v)
throws InvalidAssumptionException {
var ret = findDirectMatch(frame, v);
if (ret != null) {
return ret;
} else {
throw new InvalidAssumptionException();
}
}
@Specialization(
limit = "10",
guards = {"cachedType != null", "findType(typeOfNode, v, cachedType) == cachedType"})
Object doWithConversionCached(
VirtualFrame frame,
Object v,
@Shared("typeOfNode") @Cached TypeOfNode typeOfNode,
@Cached(value = "findType(typeOfNode, v)", dimensions = 1) Type[] cachedType,
@Cached("findConversionNode(cachedType)") ApplicationNode convertNode) {
return handleWithConversion(frame, v, convertNode);
}
@Specialization(replaces = "doWithConversionCached")
Object doWithConversionUncached(
VirtualFrame frame, Object v, @Shared("typeOfNode") @Cached TypeOfNode typeOfNode) {
var type = findType(typeOfNode, v);
return doWithConversionUncachedBoundary(frame == null ? null : frame.materialize(), v, type);
}
@ExplodeLoop
final Object findDirectMatch(VirtualFrame frame, Object v) {
if (isAllFitValue(v)) {
return v;
}
if (v instanceof Function fn && fn.isThunk()) {
if (lazyCheck == null) {
CompilerDirectives.transferToInterpreter();
var enso = EnsoLanguage.get(this);
var node = (ReadArgumentCheckNode) copy();
lazyCheck = new LazyCheckRootNode(enso, node);
}
var lazyCheckFn = lazyCheck.wrapThunk(fn);
return lazyCheckFn;
}
if (v instanceof EnsoMultiValue mv) {
if (castTo == null) {
CompilerDirectives.transferToInterpreter();
castTo = insert(EnsoMultiValue.CastToNode.create());
}
var result = castTo.executeCast(expectedType, mv);
if (result != null) {
return result;
}
}
if (checkType.execute(expectedType, v)) {
return v;
}
return null;
}
private Pair<Function, Type> findConversion(Type from) {
if (expectedType == from) {
return null;
}
var ctx = EnsoContext.get(this);
if (getRootNode() instanceof EnsoRootNode root) {
var convert = UnresolvedConversion.build(root.getModuleScope());
var conv = convert.resolveFor(ctx, expectedType, from);
if (conv != null) {
return Pair.create(conv, expectedType);
}
}
return null;
}
ApplicationNode findConversionNode(Type[] allTypes) {
if (allTypes == null) {
allTypes = new Type[] {null};
}
for (var from : allTypes) {
var convAndType = findConversion(from);
if (convAndType != null) {
if (NodeUtil.findParent(this, ReadArgumentNode.class) instanceof ReadArgumentNode ran) {
CompilerAsserts.neverPartOfCompilation();
var convNode = LiteralNode.build(convAndType.getLeft());
var intoNode = LiteralNode.build(convAndType.getRight());
var valueNode = ran.plainRead();
var args =
new CallArgument[] {
new CallArgument(null, intoNode), new CallArgument(null, valueNode)
};
return ApplicationNode.build(convNode, args, DefaultsExecutionMode.EXECUTE);
} else if (NodeUtil.findParent(this, TypeCheckExpressionNode.class)
instanceof TypeCheckExpressionNode tcen) {
CompilerAsserts.neverPartOfCompilation();
var convNode = LiteralNode.build(convAndType.getLeft());
var intoNode = LiteralNode.build(convAndType.getRight());
var valueNode = tcen.original;
var args =
new CallArgument[] {
new CallArgument(null, intoNode), new CallArgument(null, valueNode)
};
return ApplicationNode.build(convNode, args, DefaultsExecutionMode.EXECUTE);
}
}
}
return null;
}
Type[] findType(TypeOfNode typeOfNode, Object v) {
return findType(typeOfNode, v, null);
}
Type[] findType(TypeOfNode typeOfNode, Object v, Type[] previous) {
if (v instanceof EnsoMultiValue multi) {
return multi.allTypes();
}
if (v instanceof UnresolvedConstructor) {
return null;
}
if (typeOfNode.execute(v) instanceof Type from) {
if (previous != null && previous.length == 1 && previous[0] == from) {
return previous;
} else {
return new Type[] {from};
}
}
return null;
}
private Object handleWithConversion(VirtualFrame frame, Object v, ApplicationNode convertNode)
throws PanicException {
if (convertNode == null) {
var ret = findDirectMatch(frame, v);
if (ret != null) {
return ret;
}
return null;
} else {
var converted = convertNode.executeGeneric(frame);
return converted;
}
}
@CompilerDirectives.TruffleBoundary
private Object doWithConversionUncachedBoundary(
MaterializedFrame frame, Object v, Type[] type) {
var convertNode = findConversionNode(type);
return handleWithConversion(frame, v, convertNode);
}
@Override
String expectedTypeMessage() {
if (expectedTypeMessage != null) {
return expectedTypeMessage;
}
CompilerDirectives.transferToInterpreterAndInvalidate();
expectedTypeMessage = expectedType.toString();
return expectedTypeMessage;
}
}
abstract static class MetaCheckNode extends ReadArgumentCheckNode {
private final CachingSupplier<? extends Object> expectedSupplier;
@CompilerDirectives.CompilationFinal private String expectedTypeMessage;
MetaCheckNode(String name, CachingSupplier<? extends Object> expectedMetaSupplier) {
super(name);
this.expectedSupplier = expectedMetaSupplier;
}
@Override
Object findDirectMatch(VirtualFrame frame, Object value) {
return executeCheckOrConversion(frame, value);
}
@Specialization()
Object verifyMetaObject(VirtualFrame frame, Object v, @Cached IsValueOfTypeNode isA) {
if (isAllFitValue(v)) {
return v;
}
if (isA.execute(expectedSupplier.get(), v)) {
return v;
} else {
return null;
}
}
@Override
String expectedTypeMessage() {
if (expectedTypeMessage != null) {
return expectedTypeMessage;
}
CompilerDirectives.transferToInterpreterAndInvalidate();
var iop = InteropLibrary.getUncached();
try {
expectedTypeMessage = iop.asString(iop.getMetaQualifiedName(expectedSupplier.get()));
} catch (UnsupportedMessageException ex) {
expectedTypeMessage = expectedSupplier.get().toString();
}
return expectedTypeMessage;
}
}
private static final class LazyCheckRootNode extends RootNode {
@Child private ThunkExecutorNode evalThunk;
@Child private ReadArgumentCheckNode check;
static final FunctionSchema SCHEMA =
FunctionSchema.newBuilder()
.argumentDefinitions(
new ArgumentDefinition(0, "delegate", null, null, ExecutionMode.EXECUTE))
.hasPreapplied(true)
.build();
LazyCheckRootNode(TruffleLanguage<?> language, ReadArgumentCheckNode check) {
super(language);
this.check = check;
this.evalThunk = ThunkExecutorNode.build();
}
Function wrapThunk(Function thunk) {
return new Function(getCallTarget(), thunk.getScope(), SCHEMA, new Object[] {thunk}, null);
}
@Override
public Object execute(VirtualFrame frame) {
var state = Function.ArgumentsHelper.getState(frame.getArguments());
var args = Function.ArgumentsHelper.getPositionalArguments(frame.getArguments());
assert args.length == 1;
assert args[0] instanceof Function fn && fn.isThunk();
var raw = evalThunk.executeThunk(frame, args[0], state, TailStatus.NOT_TAIL);
var result = check.handleCheckOrConversion(frame, raw);
return result;
}
}
private static final class TypeCheckExpressionNode extends ExpressionNode {
@Child private ExpressionNode original;
@Child private ReadArgumentCheckNode check;
TypeCheckExpressionNode(ExpressionNode original, ReadArgumentCheckNode check) {
this.check = check;
this.original = original;
}
@Override
public Object executeGeneric(VirtualFrame frame) {
var value = original.executeGeneric(frame);
var result = check.handleCheckOrConversion(frame, value);
return result;
}
@Override
public boolean isInstrumentable() {
return false;
}
}
}

View File

@ -14,13 +14,11 @@ import org.enso.interpreter.runtime.callable.function.Function;
public final class ReadArgumentNode extends ExpressionNode { public final class ReadArgumentNode extends ExpressionNode {
private final int index; private final int index;
@Child ExpressionNode defaultValue; @Child ExpressionNode defaultValue;
@Child ReadArgumentCheckNode checkType;
private final CountingConditionProfile defaultingProfile = CountingConditionProfile.create(); private final CountingConditionProfile defaultingProfile = CountingConditionProfile.create();
private ReadArgumentNode(int position, ExpressionNode defaultValue, ReadArgumentCheckNode check) { private ReadArgumentNode(int position, ExpressionNode defaultValue) {
this.index = position; this.index = position;
this.defaultValue = defaultValue; this.defaultValue = defaultValue;
this.checkType = check;
} }
/** /**
@ -28,18 +26,10 @@ public final class ReadArgumentNode extends ExpressionNode {
* *
* @param position the argument's position at the definition site * @param position the argument's position at the definition site
* @param defaultValue the default value provided for that argument * @param defaultValue the default value provided for that argument
* @param check {@code null} or node to check type of input
* @return a node representing the argument at position {@code idx} * @return a node representing the argument at position {@code idx}
*/ */
public static ReadArgumentNode build( public static ReadArgumentNode build(int position, ExpressionNode defaultValue) {
int position, ExpressionNode defaultValue, ReadArgumentCheckNode check) { return new ReadArgumentNode(position, defaultValue);
return new ReadArgumentNode(position, defaultValue, check);
}
ReadArgumentNode plainRead() {
var node = (ReadArgumentNode) this.copy();
node.checkType = null;
return node;
} }
/** /**
@ -68,9 +58,6 @@ public final class ReadArgumentNode extends ExpressionNode {
v = arguments[index]; v = arguments[index];
} }
} }
if (checkType != null) {
v = checkType.handleCheckOrConversion(frame, v);
}
return v; return v;
} }

View File

@ -0,0 +1,93 @@
package org.enso.interpreter.node.typecheck;
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.Node;
import java.util.List;
import org.enso.interpreter.node.ExpressionNode;
import org.enso.interpreter.node.expression.builtin.meta.AtomWithAHoleNode;
import org.enso.interpreter.runtime.data.text.Text;
import org.enso.interpreter.runtime.error.DataflowError;
/**
* Root of hierarchy of nodes checking types. This class (and its subclasses) are an implementation
* detail. The API to perform the is in {@link TypeCheckNode}.
*/
abstract sealed class AbstractTypeCheckNode extends Node
permits OneOfTypesCheckNode, AllOfTypesCheckNode, SingleTypeCheckNode, MetaTypeCheckNode {
private final String comment;
@CompilerDirectives.CompilationFinal private String expectedTypeMessage;
AbstractTypeCheckNode(String comment) {
this.comment = comment;
}
abstract Object findDirectMatch(VirtualFrame frame, Object value);
abstract Object executeCheckOrConversion(
VirtualFrame frame, Object value, ExpressionNode valueNode);
abstract String expectedTypeMessage();
/**
* The error message for this node's check. Ready for being used at "fast path".
*
* @return
*/
final String getExpectedTypeMessage() {
if (expectedTypeMessage == null) {
CompilerDirectives.transferToInterpreterAndInvalidate();
expectedTypeMessage = expectedTypeMessage();
}
return expectedTypeMessage;
}
/**
* Composes the comment message describing the kind of check this node performs. Ready for being
* used at "fast path".
*
* @return description of this node's "expectations"
*/
final Text getComment() {
var where = Text.create(comment == null ? "expression" : comment);
var exp = Text.create("expected ");
var got = Text.create(" to be {exp}, but got {got}");
return Text.create(exp, Text.create(where, got));
}
final String joinTypeParts(List<String> parts, String separator) {
assert !parts.isEmpty();
if (parts.size() == 1) {
return parts.get(0);
}
var separatorWithSpace = " " + separator + " ";
var builder = new StringBuilder();
boolean isFirst = true;
for (String part : parts) {
if (isFirst) {
isFirst = false;
} else {
builder.append(separatorWithSpace);
}
// If the part contains a space, it means it is not a single type but already a more complex
// expression with a separator.
// So to ensure we don't mess up the expression layers, we need to add parentheses around it.
boolean needsParentheses = part.contains(" ");
if (needsParentheses) {
builder.append("(");
}
builder.append(part);
if (needsParentheses) {
builder.append(")");
}
}
return builder.toString();
}
static boolean isAllFitValue(Object v) {
return v instanceof DataflowError || AtomWithAHoleNode.isHole(v);
}
}

View File

@ -0,0 +1,58 @@
package org.enso.interpreter.node.typecheck;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.ExplodeLoop;
import java.util.Arrays;
import java.util.stream.Collectors;
import org.enso.interpreter.node.ExpressionNode;
import org.enso.interpreter.runtime.data.EnsoMultiValue;
import org.enso.interpreter.runtime.data.Type;
import org.enso.interpreter.runtime.library.dispatch.TypesLibrary;
final class AllOfTypesCheckNode extends AbstractTypeCheckNode {
@Children private AbstractTypeCheckNode[] checks;
@Child private TypesLibrary types;
AllOfTypesCheckNode(String name, AbstractTypeCheckNode[] checks) {
super(name);
this.checks = checks;
this.types = TypesLibrary.getFactory().createDispatched(checks.length);
}
AbstractTypeCheckNode[] getChecks() {
return checks;
}
@Override
Object findDirectMatch(VirtualFrame frame, Object value) {
return null;
}
@Override
@ExplodeLoop
Object executeCheckOrConversion(VirtualFrame frame, Object value, ExpressionNode expr) {
var values = new Object[checks.length];
var valueTypes = new Type[checks.length];
var at = 0;
for (var n : checks) {
var result = n.executeCheckOrConversion(frame, value, expr);
if (result == null) {
return null;
}
values[at] = result;
valueTypes[at] = types.getType(result);
at++;
}
return EnsoMultiValue.create(valueTypes, values);
}
@Override
String expectedTypeMessage() {
var parts =
Arrays.stream(checks)
.map(AbstractTypeCheckNode::expectedTypeMessage)
.collect(Collectors.toList());
return joinTypeParts(parts, "&");
}
}

View File

@ -0,0 +1,44 @@
package org.enso.interpreter.node.typecheck;
import com.oracle.truffle.api.TruffleLanguage;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.RootNode;
import org.enso.interpreter.node.BaseNode;
import org.enso.interpreter.node.callable.thunk.ThunkExecutorNode;
import org.enso.interpreter.runtime.callable.argument.ArgumentDefinition;
import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.runtime.callable.function.FunctionSchema;
final class LazyCheckRootNode extends RootNode {
@Child private ThunkExecutorNode evalThunk;
@Child private TypeCheckValueNode check;
static final FunctionSchema SCHEMA =
FunctionSchema.newBuilder()
.argumentDefinitions(
new ArgumentDefinition(
0, "delegate", null, null, ArgumentDefinition.ExecutionMode.EXECUTE))
.hasPreapplied(true)
.build();
LazyCheckRootNode(TruffleLanguage<?> language, TypeCheckValueNode check) {
super(language);
this.check = check;
this.evalThunk = ThunkExecutorNode.build();
}
Function wrapThunk(Function thunk) {
return new Function(getCallTarget(), thunk.getScope(), SCHEMA, new Object[] {thunk}, null);
}
@Override
public Object execute(VirtualFrame frame) {
var state = Function.ArgumentsHelper.getState(frame.getArguments());
var args = Function.ArgumentsHelper.getPositionalArguments(frame.getArguments());
assert args.length == 1;
assert args[0] instanceof Function fn && fn.isThunk();
var raw = evalThunk.executeThunk(frame, args[0], state, BaseNode.TailStatus.NOT_TAIL);
var result = check.handleCheckOrConversion(frame, raw, null);
return result;
}
}

View File

@ -0,0 +1,58 @@
package org.enso.interpreter.node.typecheck;
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.interop.InteropLibrary;
import com.oracle.truffle.api.interop.UnsupportedMessageException;
import org.enso.interpreter.node.expression.builtin.meta.IsValueOfTypeNode;
import org.enso.interpreter.runtime.util.CachingSupplier;
/**
* Node for checking {@code polyglot java import} types. This class (and its subclasses)
* are an implementation detail. The API to perform the is in {@link TypeCheckNode}.
*/
non-sealed abstract class MetaTypeCheckNode extends AbstractTypeCheckNode {
private final CachingSupplier<? extends Object> expectedSupplier;
@CompilerDirectives.CompilationFinal private String expectedTypeMessage;
MetaTypeCheckNode(String name, CachingSupplier<? extends Object> expectedMetaSupplier) {
super(name);
this.expectedSupplier = expectedMetaSupplier;
}
abstract Object executeCheckOrConversion(VirtualFrame frame, Object value);
@Override
Object findDirectMatch(VirtualFrame frame, Object value) {
return executeCheckOrConversion(frame, value);
}
@Specialization
Object verifyMetaObject(VirtualFrame frame, Object v, @Cached IsValueOfTypeNode isA) {
if (isAllFitValue(v)) {
return v;
}
if (isA.execute(expectedSupplier.get(), v)) {
return v;
} else {
return null;
}
}
@Override
String expectedTypeMessage() {
if (expectedTypeMessage != null) {
return expectedTypeMessage;
}
CompilerDirectives.transferToInterpreterAndInvalidate();
com.oracle.truffle.api.interop.InteropLibrary iop = InteropLibrary.getUncached();
try {
expectedTypeMessage = iop.asString(iop.getMetaQualifiedName(expectedSupplier.get()));
} catch (UnsupportedMessageException ex) {
expectedTypeMessage = expectedSupplier.get().toString();
}
return expectedTypeMessage;
}
}

View File

@ -0,0 +1,54 @@
package org.enso.interpreter.node.typecheck;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.ExplodeLoop;
import java.util.Arrays;
import java.util.stream.Collectors;
import org.enso.interpreter.node.ExpressionNode;
final class OneOfTypesCheckNode extends AbstractTypeCheckNode {
@Children private AbstractTypeCheckNode[] checks;
OneOfTypesCheckNode(String name, AbstractTypeCheckNode[] checks) {
super(name);
this.checks = checks;
}
@Override
@ExplodeLoop
final Object findDirectMatch(VirtualFrame frame, Object value) {
for (var n : checks) {
java.lang.Object result = n.findDirectMatch(frame, value);
if (result != null) {
return result;
}
}
return null;
}
@Override
@ExplodeLoop
Object executeCheckOrConversion(VirtualFrame frame, Object value, ExpressionNode expr) {
java.lang.Object direct = findDirectMatch(frame, value);
if (direct != null) {
return direct;
}
for (var n : checks) {
java.lang.Object result = n.executeCheckOrConversion(frame, value, expr);
if (result != null) {
return result;
}
}
return null;
}
@Override
String expectedTypeMessage() {
java.util.List<java.lang.String> parts =
Arrays.stream(checks)
.map(AbstractTypeCheckNode::expectedTypeMessage)
.collect(Collectors.toList());
return joinTypeParts(parts, "|");
}
}

View File

@ -0,0 +1,221 @@
package org.enso.interpreter.node.typecheck;
import com.oracle.truffle.api.CompilerAsserts;
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.frame.MaterializedFrame;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.ExplodeLoop;
import com.oracle.truffle.api.nodes.InvalidAssumptionException;
import com.oracle.truffle.api.nodes.Node;
import org.enso.interpreter.EnsoLanguage;
import org.enso.interpreter.node.EnsoRootNode;
import org.enso.interpreter.node.ExpressionNode;
import org.enso.interpreter.node.callable.ApplicationNode;
import org.enso.interpreter.node.callable.InvokeCallableNode;
import org.enso.interpreter.node.expression.builtin.meta.IsValueOfTypeNode;
import org.enso.interpreter.node.expression.literal.LiteralNode;
import org.enso.interpreter.runtime.EnsoContext;
import org.enso.interpreter.runtime.callable.UnresolvedConstructor;
import org.enso.interpreter.runtime.callable.UnresolvedConversion;
import org.enso.interpreter.runtime.callable.argument.CallArgument;
import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.runtime.data.EnsoMultiValue;
import org.enso.interpreter.runtime.data.Type;
import org.enso.interpreter.runtime.error.PanicException;
import org.enso.interpreter.runtime.error.PanicSentinel;
import org.enso.interpreter.runtime.library.dispatch.TypeOfNode;
import org.graalvm.collections.Pair;
non-sealed abstract class SingleTypeCheckNode extends AbstractTypeCheckNode {
private final Type expectedType;
@Node.Child IsValueOfTypeNode checkType;
@CompilerDirectives.CompilationFinal private String expectedTypeMessage;
@CompilerDirectives.CompilationFinal private LazyCheckRootNode lazyCheck;
@Node.Child private EnsoMultiValue.CastToNode castTo;
SingleTypeCheckNode(String name, Type expectedType) {
super(name);
this.checkType = IsValueOfTypeNode.build();
this.expectedType = expectedType;
}
abstract Object executeCheckOrConversion(
VirtualFrame frame, Object value, ExpressionNode valueSource);
@Specialization
Object doPanicSentinel(VirtualFrame frame, PanicSentinel panicSentinel, ExpressionNode ignore) {
throw panicSentinel;
}
@Specialization
Object doUnresolvedConstructor(
VirtualFrame frame,
UnresolvedConstructor unresolved,
ExpressionNode ignore,
@Cached UnresolvedConstructor.ConstructNode construct) {
var state = Function.ArgumentsHelper.getState(frame.getArguments());
return construct.execute(frame, state, expectedType, unresolved);
}
@Specialization(rewriteOn = InvalidAssumptionException.class)
Object doCheckNoConversionNeeded(VirtualFrame frame, Object v, ExpressionNode ignore)
throws InvalidAssumptionException {
var ret = findDirectMatch(frame, v);
if (ret != null) {
return ret;
} else {
throw new InvalidAssumptionException();
}
}
@Specialization(
limit = "10",
guards = {"cachedType != null", "findType(typeOfNode, v, cachedType) == cachedType"})
Object doWithConversionCached(
VirtualFrame frame,
Object v,
ExpressionNode valueSource,
@Cached.Shared("typeOfNode") @Cached TypeOfNode typeOfNode,
@Cached(value = "findType(typeOfNode, v)", dimensions = 1) Type[] cachedType,
@Cached("findConversionNode(valueSource, cachedType)") ApplicationNode convertNode) {
return handleWithConversion(frame, v, convertNode);
}
@Specialization(replaces = "doWithConversionCached")
Object doWithConversionUncached(
VirtualFrame frame,
Object v,
ExpressionNode expr,
@Cached.Shared("typeOfNode") @Cached TypeOfNode typeOfNode) {
var type = findType(typeOfNode, v);
return doWithConversionUncachedBoundary(
frame == null ? null : frame.materialize(), v, expr, type);
}
@ExplodeLoop
final Object findDirectMatch(VirtualFrame frame, Object v) {
if (isAllFitValue(v)) {
return v;
}
if (v instanceof Function fn && fn.isThunk()) {
if (lazyCheck == null) {
CompilerDirectives.transferToInterpreter();
var enso = EnsoLanguage.get(this);
var node = (AbstractTypeCheckNode) copy();
lazyCheck = new LazyCheckRootNode(enso, new TypeCheckValueNode(node));
}
var lazyCheckFn = lazyCheck.wrapThunk(fn);
return lazyCheckFn;
}
assert EnsoContext.get(this).getBuiltins().any() != expectedType : "Don't check for Any: " + expectedType;
if (v instanceof EnsoMultiValue mv) {
if (castTo == null) {
CompilerDirectives.transferToInterpreter();
castTo = insert(EnsoMultiValue.CastToNode.create());
}
var result = castTo.executeCast(expectedType, mv);
if (result != null) {
return result;
}
}
if (checkType.execute(expectedType, v)) {
return v;
}
return null;
}
private Pair<Function, Type> findConversion(Type from) {
if (expectedType == from) {
return null;
}
var ctx = EnsoContext.get(this);
if (getRootNode() instanceof EnsoRootNode root) {
var convert = UnresolvedConversion.build(root.getModuleScope());
var conv = convert.resolveFor(ctx, expectedType, from);
if (conv != null) {
return Pair.create(conv, expectedType);
}
}
return null;
}
ApplicationNode findConversionNode(ExpressionNode valueNode, Type[] allTypes) {
if (valueNode == null) {
return null;
}
if (allTypes == null) {
allTypes = new Type[] {null};
}
for (var from : allTypes) {
var convAndType = findConversion(from);
if (convAndType != null) {
CompilerAsserts.neverPartOfCompilation();
var convNode = LiteralNode.build(convAndType.getLeft());
var intoNode = LiteralNode.build(convAndType.getRight());
var args =
new CallArgument[] {
new CallArgument(null, intoNode), new CallArgument(null, valueNode)
};
return ApplicationNode.build(
convNode, args, InvokeCallableNode.DefaultsExecutionMode.EXECUTE);
}
}
return null;
}
Type[] findType(TypeOfNode typeOfNode, Object v) {
return findType(typeOfNode, v, null);
}
Type[] findType(TypeOfNode typeOfNode, Object v, Type[] previous) {
if (v instanceof EnsoMultiValue multi) {
return multi.allTypes();
}
if (v instanceof UnresolvedConstructor) {
return null;
}
if (typeOfNode.execute(v) instanceof Type from) {
if (previous != null && previous.length == 1 && previous[0] == from) {
return previous;
} else {
return new Type[] {from};
}
}
return null;
}
private Object handleWithConversion(VirtualFrame frame, Object v, ApplicationNode convertNode)
throws PanicException {
if (convertNode == null) {
var ret = findDirectMatch(frame, v);
if (ret != null) {
return ret;
}
return null;
} else {
var converted = convertNode.executeGeneric(frame);
return converted;
}
}
@CompilerDirectives.TruffleBoundary
private Object doWithConversionUncachedBoundary(
MaterializedFrame frame, Object v, ExpressionNode expr, Type[] type) {
var convertNode = findConversionNode(expr, type);
return handleWithConversion(frame, v, convertNode);
}
@Override
String expectedTypeMessage() {
if (expectedTypeMessage != null) {
return expectedTypeMessage;
}
CompilerDirectives.transferToInterpreterAndInvalidate();
expectedTypeMessage = expectedType.toString();
return expectedTypeMessage;
}
}

View File

@ -0,0 +1,31 @@
package org.enso.interpreter.node.typecheck;
import com.oracle.truffle.api.frame.VirtualFrame;
import org.enso.interpreter.node.ExpressionNode;
final class TypeCheckExpressionNode extends ExpressionNode {
@Child private ExpressionNode original;
@Child private TypeCheckValueNode check;
TypeCheckExpressionNode(ExpressionNode original, TypeCheckValueNode check) {
this.check = check;
this.original = original;
}
ExpressionNode getOriginal() {
return original;
}
@Override
public Object executeGeneric(VirtualFrame frame) {
java.lang.Object value = original.executeGeneric(frame);
java.lang.Object result = check.handleCheckOrConversion(frame, value, original);
return result;
}
@Override
public boolean isInstrumentable() {
return false;
}
}

View File

@ -0,0 +1,182 @@
package org.enso.interpreter.node.typecheck;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.Node;
import java.util.Arrays;
import java.util.List;
import java.util.function.Supplier;
import java.util.stream.Stream;
import org.enso.interpreter.node.ExpressionNode;
import org.enso.interpreter.runtime.EnsoContext;
import org.enso.interpreter.runtime.callable.UnresolvedConstructor;
import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.runtime.data.Type;
import org.enso.interpreter.runtime.data.text.Text;
import org.enso.interpreter.runtime.error.PanicException;
import org.enso.interpreter.runtime.util.CachingSupplier;
/** A node and a factory for nodes performing type checks (including necessary conversions). */
public final class TypeCheckValueNode extends Node {
private @Child AbstractTypeCheckNode check;
TypeCheckValueNode(AbstractTypeCheckNode check) {
assert check != null;
this.check = check;
}
/**
* Wraps expression node with additional type check.
*
* @param original the expression node
* @param check node performing type check or {@code null}
* @return wrapper around {@code original} or directly {@code original} if there is {@code null}
* check
*/
public static ExpressionNode wrap(ExpressionNode original, TypeCheckValueNode check) {
if (check == null) {
return original;
} else {
return new TypeCheckExpressionNode(original, check);
}
}
/**
* Executes check or conversion of the value.
*
* @param frame frame requesting the conversion
* @param value the value to convert
* @param expr the expression node that produced the {@code value}
* @return {@code null} when the check isn't satisfied and conversion isn't possible or non-{@code
* null} value that can be used as a result
*/
public final Object handleCheckOrConversion(
VirtualFrame frame, Object value, ExpressionNode expr) {
var result = check.executeCheckOrConversion(frame, value, expr);
if (result == null) {
throw panicAtTheEnd(value);
}
return result;
}
/**
* Combines existing type checks into "all of" check.
*
* @param comment description of the check meaning
* @param checks existing type checks
* @return node the composed check or {@code null} if no check is needed
*/
public static TypeCheckValueNode allOf(String comment, TypeCheckValueNode... checks) {
if (checks == null) {
return null;
}
var list = Arrays.asList(checks);
var flatten =
list.stream()
.filter(n -> n != null)
.map(n -> n.check)
.flatMap(
n ->
n instanceof AllOfTypesCheckNode all
? Arrays.asList(all.getChecks()).stream()
: Stream.of(n))
.toList();
var arr = toArray(flatten);
return switch (arr.length) {
case 0 -> null;
case 1 -> new TypeCheckValueNode(arr[0]);
default -> new TypeCheckValueNode(new AllOfTypesCheckNode(comment, arr));
};
}
/**
* Combines existing checks into "one of" check.
*
* @param comment description of the check meaning
* @param checks existing type checks
* @return node the composed check or {@code null} if no check is needed
*/
public static TypeCheckValueNode oneOf(String comment, TypeCheckValueNode... checks) {
if (checks == null) {
return null;
}
var list = Stream.of(checks).filter(n -> n != null).toList();
return switch (list.size()) {
case 0 -> null;
case 1 -> list.get(0);
default -> {
var abstractTypeCheckList = list.stream().map(n -> n.check).toList();
var abstractTypeCheckArr = toArray(abstractTypeCheckList);
yield new TypeCheckValueNode(new OneOfTypesCheckNode(comment, abstractTypeCheckArr));
}
};
}
/**
* Constructs "single type" check.
*
* @param comment description of the check meaning
* @param expectedType the type to check for - it shouldn't be {@code Any}
* @return node performing the check
*/
public static TypeCheckValueNode single(String comment, Type expectedType) {
var typeCheckNodeImpl = SingleTypeCheckNodeGen.create(comment, expectedType);
return new TypeCheckValueNode(typeCheckNodeImpl);
}
/**
* Constructs node to check for {@code polyglot java import} checks.
*
* @param comment description of the check meaning
* @param metaObjectSupplier provider of the meta object to check for
* @return node performing the check
*/
public static TypeCheckValueNode meta(
String comment, Supplier<? extends Object> metaObjectSupplier) {
var cachingSupplier = CachingSupplier.wrap(metaObjectSupplier);
var typeCheckNodeImpl = MetaTypeCheckNodeGen.create(comment, cachingSupplier);
return new TypeCheckValueNode(typeCheckNodeImpl);
}
/**
* Check whether given function is "lazy thunk". E.g. if it is a thunk wrapped by "lazy type
* check".
*
* @param fn function to check
* @return result of the check
*/
public static boolean isWrappedThunk(Function fn) {
if (fn.getSchema() == LazyCheckRootNode.SCHEMA) {
return fn.getPreAppliedArguments()[0] instanceof Function wrappedFn && wrappedFn.isThunk();
}
return false;
}
private final PanicException panicAtTheEnd(Object v) {
var expectedTypeMessage = check.getExpectedTypeMessage();
var ctx = EnsoContext.get(this);
Text msg;
if (v instanceof UnresolvedConstructor) {
msg = Text.create("Cannot find constructor {got} among {exp}");
} else {
msg = check.getComment();
}
var err = ctx.getBuiltins().error().makeTypeErrorOfComment(expectedTypeMessage, v, msg);
throw new PanicException(err, this);
}
private static AbstractTypeCheckNode[] toArray(List<AbstractTypeCheckNode> list) {
if (list == null) {
return new AbstractTypeCheckNode[0];
}
var cnt = (int) list.stream().filter(n -> n != null).count();
var arr = new AbstractTypeCheckNode[cnt];
var it = list.iterator();
for (int i = 0; i < cnt; ) {
var element = it.next();
if (element != null) {
arr[i++] = element;
}
}
return arr;
}
}

View File

@ -188,12 +188,11 @@ public final class UnresolvedConstructor implements EnsoObject {
id = callable.getId(); id = callable.getId();
} }
} }
var fn = ReadArgumentNode.build(0, null, null); var fn = ReadArgumentNode.build(0, null);
var args = new CallArgument[prototype.descs.length]; var args = new CallArgument[prototype.descs.length];
for (var i = 0; i < args.length; i++) { for (var i = 0; i < args.length; i++) {
args[i] = args[i] =
new CallArgument( new CallArgument(prototype.descs[i].getName(), ReadArgumentNode.build(1 + i, null));
prototype.descs[i].getName(), ReadArgumentNode.build(1 + i, null, null));
} }
var expr = ApplicationNode.build(fn, args, DefaultsExecutionMode.EXECUTE); var expr = ApplicationNode.build(fn, args, DefaultsExecutionMode.EXECUTE);
if (id != null) { if (id != null) {

View File

@ -2,7 +2,7 @@ package org.enso.interpreter.runtime.callable.argument;
import java.util.Optional; import java.util.Optional;
import org.enso.interpreter.node.ExpressionNode; import org.enso.interpreter.node.ExpressionNode;
import org.enso.interpreter.node.callable.argument.ReadArgumentCheckNode; import org.enso.interpreter.node.typecheck.TypeCheckValueNode;
/** Tracks the specifics about how arguments are defined at the callable definition site. */ /** Tracks the specifics about how arguments are defined at the callable definition site. */
public final class ArgumentDefinition { public final class ArgumentDefinition {
@ -26,7 +26,7 @@ public final class ArgumentDefinition {
private final int position; private final int position;
private final String name; private final String name;
private final ReadArgumentCheckNode checkType; private final TypeCheckValueNode checkType;
private final ExpressionNode defaultValue; private final ExpressionNode defaultValue;
private final boolean isSuspended; private final boolean isSuspended;
@ -42,7 +42,7 @@ public final class ArgumentDefinition {
public ArgumentDefinition( public ArgumentDefinition(
int position, int position,
String name, String name,
ReadArgumentCheckNode checkType, TypeCheckValueNode checkType,
ExpressionNode defaultValue, ExpressionNode defaultValue,
ExecutionMode executionMode) { ExecutionMode executionMode) {
this.position = position; this.position = position;
@ -103,7 +103,7 @@ public final class ArgumentDefinition {
* *
* @return {@code null} or list of types to check argument against * @return {@code null} or list of types to check argument against
*/ */
public ReadArgumentCheckNode getCheckType() { public TypeCheckValueNode getCheckType() {
return checkType; return checkType;
} }
} }

View File

@ -107,7 +107,7 @@ public final class AtomConstructor implements EnsoObject {
EnsoLanguage language, ModuleScope.Builder scopeBuilder, ArgumentDefinition... args) { EnsoLanguage language, ModuleScope.Builder scopeBuilder, ArgumentDefinition... args) {
ExpressionNode[] reads = new ExpressionNode[args.length]; ExpressionNode[] reads = new ExpressionNode[args.length];
for (int i = 0; i < args.length; i++) { for (int i = 0; i < args.length; i++) {
reads[i] = ReadArgumentNode.build(i, null, null); reads[i] = ReadArgumentNode.build(i, null);
} }
return initializeFields( return initializeFields(
language, language,

View File

@ -5,7 +5,7 @@ import com.oracle.truffle.api.dsl.NodeFactory;
import com.oracle.truffle.api.nodes.Node; import com.oracle.truffle.api.nodes.Node;
import java.util.List; import java.util.List;
import org.enso.interpreter.dsl.atom.LayoutSpec; import org.enso.interpreter.dsl.atom.LayoutSpec;
import org.enso.interpreter.node.callable.argument.ReadArgumentCheckNode; import org.enso.interpreter.node.typecheck.TypeCheckValueNode;
import org.enso.interpreter.runtime.callable.argument.ArgumentDefinition; import org.enso.interpreter.runtime.callable.argument.ArgumentDefinition;
/** /**
@ -232,12 +232,12 @@ class Layout {
private static final class SetterTypeCheckFactory private static final class SetterTypeCheckFactory
implements NodeFactory<UnboxingAtom.FieldSetterNode> { implements NodeFactory<UnboxingAtom.FieldSetterNode> {
private final String argName; private final String argName;
private final ReadArgumentCheckNode typeCheck; private final TypeCheckValueNode typeCheck;
private final NodeFactory<UnboxingAtom.FieldSetterNode> delegate; private final NodeFactory<UnboxingAtom.FieldSetterNode> delegate;
private SetterTypeCheckFactory( private SetterTypeCheckFactory(
ArgumentDefinition arg, ArgumentDefinition arg,
ReadArgumentCheckNode typeCheck, TypeCheckValueNode typeCheck,
NodeFactory<UnboxingAtom.FieldSetterNode> factory) { NodeFactory<UnboxingAtom.FieldSetterNode> factory) {
assert factory != null; assert factory != null;
this.argName = arg.getName(); this.argName = arg.getName();
@ -247,7 +247,7 @@ class Layout {
@Override @Override
public UnboxingAtom.FieldSetterNode createNode(Object... arguments) { public UnboxingAtom.FieldSetterNode createNode(Object... arguments) {
var checkNode = (ReadArgumentCheckNode) typeCheck.copy(); var checkNode = (TypeCheckValueNode) typeCheck.copy();
var setterNode = delegate.createNode(arguments); var setterNode = delegate.createNode(arguments);
return checkNode == null ? setterNode : new CheckFieldSetterNode(setterNode, checkNode); return checkNode == null ? setterNode : new CheckFieldSetterNode(setterNode, checkNode);
} }
@ -274,18 +274,18 @@ class Layout {
} }
private static final class CheckFieldSetterNode extends UnboxingAtom.FieldSetterNode { private static final class CheckFieldSetterNode extends UnboxingAtom.FieldSetterNode {
@Child ReadArgumentCheckNode checkNode; @Child TypeCheckValueNode checkNode;
@Child UnboxingAtom.FieldSetterNode setterNode; @Child UnboxingAtom.FieldSetterNode setterNode;
private CheckFieldSetterNode( private CheckFieldSetterNode(
UnboxingAtom.FieldSetterNode setterNode, ReadArgumentCheckNode checkNode) { UnboxingAtom.FieldSetterNode setterNode, TypeCheckValueNode checkNode) {
this.setterNode = setterNode; this.setterNode = setterNode;
this.checkNode = checkNode; this.checkNode = checkNode;
} }
@Override @Override
public void execute(Atom atom, Object value) { public void execute(Atom atom, Object value) {
var valueOrConvertedValue = checkNode.handleCheckOrConversion(null, value); var valueOrConvertedValue = checkNode.handleCheckOrConversion(null, value, null);
setterNode.execute(atom, valueOrConvertedValue); setterNode.execute(atom, valueOrConvertedValue);
} }
} }

View File

@ -6,8 +6,8 @@ import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.api.profiles.BranchProfile; import com.oracle.truffle.api.profiles.BranchProfile;
import java.util.List; import java.util.List;
import org.enso.interpreter.node.callable.InvokeCallableNode; import org.enso.interpreter.node.callable.InvokeCallableNode;
import org.enso.interpreter.node.callable.argument.ReadArgumentCheckNode;
import org.enso.interpreter.node.callable.dispatch.InvokeFunctionNode; import org.enso.interpreter.node.callable.dispatch.InvokeFunctionNode;
import org.enso.interpreter.node.typecheck.TypeCheckValueNode;
import org.enso.interpreter.runtime.callable.argument.CallArgumentInfo; import org.enso.interpreter.runtime.callable.argument.CallArgumentInfo;
import org.enso.interpreter.runtime.callable.function.Function; import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.runtime.data.atom.UnboxingAtom.FieldGetterNode; import org.enso.interpreter.runtime.data.atom.UnboxingAtom.FieldGetterNode;
@ -86,7 +86,7 @@ final class SuspendedFieldGetterNode extends UnboxingAtom.FieldGetterNode {
} }
private static boolean shallBeExtracted(Function fn) { private static boolean shallBeExtracted(Function fn) {
return fn.isThunk() || ReadArgumentCheckNode.isWrappedThunk(fn); return fn.isThunk() || TypeCheckValueNode.isWrappedThunk(fn);
} }
@Override @Override

View File

@ -67,7 +67,7 @@ import org.enso.compiler.pass.resolve.{
TypeSignatures TypeSignatures
} }
import org.enso.interpreter.node.callable.argument.ReadArgumentNode import org.enso.interpreter.node.callable.argument.ReadArgumentNode
import org.enso.interpreter.node.callable.argument.ReadArgumentCheckNode import org.enso.interpreter.node.typecheck.TypeCheckValueNode
import org.enso.interpreter.node.callable.function.{ import org.enso.interpreter.node.callable.function.{
BlockNode, BlockNode,
CreateFunctionNode CreateFunctionNode
@ -306,12 +306,12 @@ class IrToTruffle(
.asInstanceOf[FramePointer] .asInstanceOf[FramePointer]
val slotIdx = fp.frameSlotIdx() val slotIdx = fp.frameSlotIdx()
argDefs(idx) = arg argDefs(idx) = arg
val readArg = val readArgNoCheck =
ReadArgumentNode.build( ReadArgumentNode.build(
idx, idx,
arg.getDefaultValue.orElse(null), arg.getDefaultValue.orElse(null)
checkNode
) )
val readArg = TypeCheckValueNode.wrap(readArgNoCheck, checkNode)
val assignmentArg = AssignmentNode.build(readArg, slotIdx) val assignmentArg = AssignmentNode.build(readArg, slotIdx)
val argRead = val argRead =
ReadLocalVariableNode.build(new FramePointer(0, slotIdx)) ReadLocalVariableNode.build(new FramePointer(0, slotIdx))
@ -806,7 +806,7 @@ class IrToTruffle(
methodDef.methodName.name, methodDef.methodName.name,
fn.arguments, fn.arguments,
fn.body, fn.body,
ReadArgumentCheckNode.build(context, "conversion", toType), TypeCheckValueNode.single("conversion", toType),
None, None,
true true
) )
@ -847,27 +847,24 @@ class IrToTruffle(
private def extractAscribedType( private def extractAscribedType(
comment: String, comment: String,
t: Expression t: Expression
): ReadArgumentCheckNode = t match { ): TypeCheckValueNode = t match {
case u: `type`.Set.Union => case u: `type`.Set.Union =>
val oneOf = u.operands.map(extractAscribedType(comment, _)) val oneOf = u.operands.map(extractAscribedType(comment, _))
if (oneOf.contains(null)) { if (oneOf.contains(null)) {
null null
} else { } else {
ReadArgumentCheckNode.oneOf( val arr: Array[TypeCheckValueNode] = oneOf.toArray
comment, TypeCheckValueNode.oneOf(comment, arr: _*)
oneOf.asJava
)
} }
case i: `type`.Set.Intersection => case i: `type`.Set.Intersection =>
ReadArgumentCheckNode.allOf( TypeCheckValueNode.allOf(
comment, comment,
extractAscribedType(comment, i.left), extractAscribedType(comment, i.left),
extractAscribedType(comment, i.right) extractAscribedType(comment, i.right)
) )
case p: Application.Prefix => extractAscribedType(comment, p.function) case p: Application.Prefix => extractAscribedType(comment, p.function)
case _: Tpe.Function => case _: Tpe.Function =>
ReadArgumentCheckNode.build( TypeCheckValueNode.single(
context,
comment, comment,
context.getTopScope().getBuiltins().function() context.getTopScope().getBuiltins().function()
) )
@ -888,13 +885,13 @@ class IrToTruffle(
if (context.getBuiltins().any() == typeOrAny) { if (context.getBuiltins().any() == typeOrAny) {
null null
} else { } else {
ReadArgumentCheckNode.build(context, comment, typeOrAny) TypeCheckValueNode.single(comment, typeOrAny)
} }
case Some( case Some(
BindingsMap BindingsMap
.Resolution(BindingsMap.ResolvedPolyglotSymbol(mod, symbol)) .Resolution(BindingsMap.ResolvedPolyglotSymbol(mod, symbol))
) => ) =>
ReadArgumentCheckNode.meta( TypeCheckValueNode.meta(
comment, comment,
asScope( asScope(
mod.unsafeAsModule().asInstanceOf[TruffleCompilerContext.Module] mod.unsafeAsModule().asInstanceOf[TruffleCompilerContext.Module]
@ -907,7 +904,7 @@ class IrToTruffle(
private def checkAsTypes( private def checkAsTypes(
arg: DefinitionArgument arg: DefinitionArgument
): ReadArgumentCheckNode = { ): TypeCheckValueNode = {
val comment = "`" + arg.name.name + "`" val comment = "`" + arg.name.name + "`"
arg.ascribedType.map(extractAscribedType(comment, _)).getOrElse(null) arg.ascribedType.map(extractAscribedType(comment, _)).getOrElse(null)
} }
@ -1312,7 +1309,7 @@ class IrToTruffle(
extractAscribedType(asc.comment.orNull, asc.signature) extractAscribedType(asc.comment.orNull, asc.signature)
if (checkNode != null) { if (checkNode != null) {
val body = run(asc.typed, binding, subjectToInstrumentation) val body = run(asc.typed, binding, subjectToInstrumentation)
ReadArgumentCheckNode.wrap(body, checkNode) TypeCheckValueNode.wrap(body, checkNode)
} else { } else {
processType(asc) processType(asc)
} }
@ -1341,7 +1338,7 @@ class IrToTruffle(
extractAscribedType(tpe.comment.orNull, tpe.signature) extractAscribedType(tpe.comment.orNull, tpe.signature)
if (checkNode != null) { if (checkNode != null) {
runtimeExpression = runtimeExpression =
ReadArgumentCheckNode.wrap(runtimeExpression, checkNode) TypeCheckValueNode.wrap(runtimeExpression, checkNode)
} }
} }
} }
@ -2192,7 +2189,7 @@ class IrToTruffle(
val initialName: String, val initialName: String,
val arguments: List[DefinitionArgument], val arguments: List[DefinitionArgument],
val body: Expression, val body: Expression,
val typeCheck: ReadArgumentCheckNode, val typeCheck: TypeCheckValueNode,
val effectContext: Option[String], val effectContext: Option[String],
val subjectToInstrumentation: Boolean val subjectToInstrumentation: Boolean
) { ) {
@ -2229,7 +2226,7 @@ class IrToTruffle(
if (typeCheck == null) { if (typeCheck == null) {
(argExpressions.toArray, bodyExpr) (argExpressions.toArray, bodyExpr)
} else { } else {
val bodyWithCheck = ReadArgumentCheckNode.wrap(bodyExpr, typeCheck) val bodyWithCheck = TypeCheckValueNode.wrap(bodyExpr, typeCheck)
(argExpressions.toArray, bodyWithCheck) (argExpressions.toArray, bodyWithCheck)
} }
} }
@ -2255,12 +2252,12 @@ class IrToTruffle(
) )
.asInstanceOf[FramePointer] .asInstanceOf[FramePointer]
val slotIdx = fp.frameSlotIdx() val slotIdx = fp.frameSlotIdx()
val readArg = val readArgNoCheck =
ReadArgumentNode.build( ReadArgumentNode.build(
idx, idx,
arg.getDefaultValue.orElse(null), arg.getDefaultValue.orElse(null)
checkNode
) )
val readArg = TypeCheckValueNode.wrap(readArgNoCheck, checkNode)
val assignArg = AssignmentNode.build(readArg, slotIdx) val assignArg = AssignmentNode.build(readArg, slotIdx)
argExpressions.append(assignArg) argExpressions.append(assignArg)
@ -2589,7 +2586,7 @@ class IrToTruffle(
def run( def run(
inputArg: DefinitionArgument, inputArg: DefinitionArgument,
position: Int, position: Int,
types: ReadArgumentCheckNode types: TypeCheckValueNode
): ArgumentDefinition = ): ArgumentDefinition =
inputArg match { inputArg match {
case arg: DefinitionArgument.Specified => case arg: DefinitionArgument.Specified =>

View File

@ -0,0 +1,33 @@
import Standard.Base.Data.Numbers.Float
import Standard.Base.Data.Numbers.Number
import Standard.Base.Data.Ordering.Comparable
import Standard.Base.Data.Ordering.Ordering
import project.Data.Complex_Helpers
import project.Data.Complex_Helpers.Complex_Comparator
## Sample definition of a complex number with conversions
from Number and implementation of a comparator.
type Complex
private Value re:Float im:Float
new re=0:Float im=0:Float =
c = Complex.Value re im
if im != 0 then c:Complex else
c.as_complex_and_float
+ self (that:Complex) = Complex.new self.re+that.re self.im+that.im
< self (that:Complex) = Complex_Comparator.compare self that == Ordering.Less
> self (that:Complex) = Complex_Comparator.compare self that == Ordering.Greater
>= self (that:Complex) =
ordering = Complex_Comparator.compare self that
ordering == Ordering.Greater || ordering == Ordering.Equal
<= self (that:Complex) =
ordering = Complex_Comparator.compare self that
ordering == Ordering.Less || ordering == Ordering.Equal
Complex.from (that:Number) = Complex.new that
Comparable.from (that:Complex) = Comparable.new that Complex_Comparator
Comparable.from (that:Number) = Comparable.new that Complex_Comparator

View File

@ -0,0 +1,25 @@
private
import Standard.Base.Nothing
import Standard.Base.Data.Numbers.Float
import Standard.Base.Data.Ordering.Ordering
import Standard.Base.Error.Error
import Standard.Base.Errors.Illegal_Argument.Illegal_Argument
import project.Data.Complex.Complex
type Complex_Comparator
compare x:Complex y:Complex = if x.re==y.re && x.im==y.im then Ordering.Equal else
if x.im==0 && y.im==0 then Ordering.compare x.re y.re else
Nothing
hash x:Complex = if x.im == 0 then Ordering.hash x.re else
7*x.re + 11*x.im
## uses the explicit conversion defined in this private module
Complex.as_complex_and_float self =
self : Complex&Float
## explicit "conversion" of `Complex` to `Float` in a private module
used in `as_complex_and_float`
Float.from (that:Complex) =
if that.im == 0 then that.re else
Error.throw <| Illegal_Argument.Error "Cannot convert Complex with imaginary part to Float"

View File

@ -9,6 +9,7 @@ from Standard.Base.Data.Numbers import Number_Parse_Error
from Standard.Test import all from Standard.Test import all
import project.Data.Round_Spec import project.Data.Round_Spec
import project.Data.Complex.Complex
polyglot java import java.math.BigInteger polyglot java import java.math.BigInteger
polyglot java import java.math.BigDecimal polyglot java import java.math.BigDecimal
@ -18,34 +19,6 @@ Integer.is_even self = self % 2 == 0
Float.get_fun_factor self = "Wow, " + self.to_text + " is such a fun number!" Float.get_fun_factor self = "Wow, " + self.to_text + " is such a fun number!"
type Complex
Value re:Float im:Float
new re=0 im=0 = Complex.Value re im
+ self (that:Complex) = Complex.new self.re+that.re self.im+that.im
< self (that:Complex) = Complex_Comparator.compare self that == Ordering.Less
> self (that:Complex) = Complex_Comparator.compare self that == Ordering.Greater
>= self (that:Complex) =
ordering = Complex_Comparator.compare self that
ordering == Ordering.Greater || ordering == Ordering.Equal
<= self (that:Complex) =
ordering = Complex_Comparator.compare self that
ordering == Ordering.Less || ordering == Ordering.Equal
Complex.from (that:Number) = Complex.new that
type Complex_Comparator
compare x:Complex y:Complex = if x.re==y.re && x.im==y.im then Ordering.Equal else
if x.im==0 && y.im==0 then Ordering.compare x.re y.re else
Nothing
hash x:Complex = if x.im == 0 then Ordering.hash x.re else
7*x.re + 11*x.im
Comparable.from (that:Complex) = Comparable.new that Complex_Comparator
Comparable.from (that:Number) = Comparable.new that Complex_Comparator
float_id : Float -> Float float_id : Float -> Float
float_id f = float_id f =
Number_Utils.floatId f Number_Utils.floatId f

View File

@ -11,6 +11,7 @@ import project.Semantic.Import_Loop.Spec as Import_Loop_Spec
import project.Semantic.Meta_Spec import project.Semantic.Meta_Spec
import project.Semantic.Instrumentor_Spec import project.Semantic.Instrumentor_Spec
import project.Semantic.Meta_Location_Spec import project.Semantic.Meta_Location_Spec
import project.Semantic.Multi_Value_Spec
import project.Semantic.Names_Spec import project.Semantic.Names_Spec
import project.Semantic.Equals_Spec import project.Semantic.Equals_Spec
import project.Semantic.Runtime_Spec import project.Semantic.Runtime_Spec
@ -132,6 +133,7 @@ main filter=Nothing =
Meta_Spec.add_specs suite_builder Meta_Spec.add_specs suite_builder
Instrumentor_Spec.add_specs suite_builder Instrumentor_Spec.add_specs suite_builder
Meta_Location_Spec.add_specs suite_builder Meta_Location_Spec.add_specs suite_builder
Multi_Value_Spec.add_specs suite_builder
Names_Spec.add_specs suite_builder Names_Spec.add_specs suite_builder
Numbers_Spec.add_specs suite_builder Numbers_Spec.add_specs suite_builder
Equals_Spec.add_specs suite_builder Equals_Spec.add_specs suite_builder

View File

@ -0,0 +1,20 @@
from Standard.Base import all
from Standard.Test import all
import Standard.Base.Errors.Common.Type_Error
import project.Data.Complex.Complex
add_specs suite_builder =
suite_builder.group "Complex Multi Value" group_builder->
group_builder.specify "Cannot convert to Float if it has imaginary part" <|
c = Complex.new 1 5
Test.expect_panic Type_Error (c:Float)
group_builder.specify "Represents both Complex & Float with only real part" <|
c = Complex.new 1.5 0.0
(c:Complex).re . should_equal 1.5
(c:Float) . should_equal 1.5
main filter=Nothing =
suite = Test.build suite_builder->
add_specs suite_builder
suite.run_with_filter filter