From 73b93f5e6bc9d17313e68b77847f5d40fc016225 Mon Sep 17 00:00:00 2001 From: Jaroslav Tulach Date: Tue, 19 Nov 2024 18:04:42 +0100 Subject: [PATCH] Multi value `Complex` test and robustness refactoring (#11525) While working on #11482 and enhancing the tests suite with more tests based on `type Complex` a [getRootNode() did not terminate in 100000 iterations](https://github.com/enso-org/enso/pull/11525#issuecomment-2476171597) problem was discovered. Detailed investigation revealed that the existing `ReadArgumentCheckNode` infrastructure was able to create a **cycle** of parent pointers in the Truffle AST. The problem was in intricate manipulation of the AST while rewriting internals in `ReadArgumentCheckNode`. This PR avoids such manipulation by _refactoring the type checking code_. `ReadArgumentNode` knows nothing about types anymore. When a type check is needed, `IrToTruffle` adds additional `TypeCheckValueNode.wrap` around the `ReadArgumentNode` - that breaks the **vicious circle**. All the _type checks_ nodes are moved to its own package. All but one of the classes are made package private. The external API for doing _type checking_ is concentrated into `TypeCheckValueNode`. --- .../argument/ReadArgumentCheckNode.java | 578 ------------------ .../callable/argument/ReadArgumentNode.java | 19 +- .../node/typecheck/AbstractTypeCheckNode.java | 93 +++ .../node/typecheck/AllOfTypesCheckNode.java | 58 ++ .../node/typecheck/LazyCheckRootNode.java | 44 ++ .../node/typecheck/MetaTypeCheckNode.java | 58 ++ .../node/typecheck/OneOfTypesCheckNode.java | 54 ++ .../node/typecheck/SingleTypeCheckNode.java | 221 +++++++ .../typecheck/TypeCheckExpressionNode.java | 31 + .../node/typecheck/TypeCheckValueNode.java | 182 ++++++ .../callable/UnresolvedConstructor.java | 5 +- .../callable/argument/ArgumentDefinition.java | 8 +- .../runtime/data/atom/AtomConstructor.java | 2 +- .../interpreter/runtime/data/atom/Layout.java | 14 +- .../data/atom/SuspendedFieldGetterNode.java | 4 +- .../interpreter/runtime/IrToTruffle.scala | 45 +- test/Base_Tests/src/Data/Complex.enso | 33 + test/Base_Tests/src/Data/Complex_Helpers.enso | 25 + test/Base_Tests/src/Data/Numbers_Spec.enso | 29 +- test/Base_Tests/src/Main.enso | 2 + .../src/Semantic/Multi_Value_Spec.enso | 20 + 21 files changed, 862 insertions(+), 663 deletions(-) delete mode 100644 engine/runtime/src/main/java/org/enso/interpreter/node/callable/argument/ReadArgumentCheckNode.java create mode 100644 engine/runtime/src/main/java/org/enso/interpreter/node/typecheck/AbstractTypeCheckNode.java create mode 100644 engine/runtime/src/main/java/org/enso/interpreter/node/typecheck/AllOfTypesCheckNode.java create mode 100644 engine/runtime/src/main/java/org/enso/interpreter/node/typecheck/LazyCheckRootNode.java create mode 100644 engine/runtime/src/main/java/org/enso/interpreter/node/typecheck/MetaTypeCheckNode.java create mode 100644 engine/runtime/src/main/java/org/enso/interpreter/node/typecheck/OneOfTypesCheckNode.java create mode 100644 engine/runtime/src/main/java/org/enso/interpreter/node/typecheck/SingleTypeCheckNode.java create mode 100644 engine/runtime/src/main/java/org/enso/interpreter/node/typecheck/TypeCheckExpressionNode.java create mode 100644 engine/runtime/src/main/java/org/enso/interpreter/node/typecheck/TypeCheckValueNode.java create mode 100644 test/Base_Tests/src/Data/Complex.enso create mode 100644 test/Base_Tests/src/Data/Complex_Helpers.enso create mode 100644 test/Base_Tests/src/Semantic/Multi_Value_Spec.enso diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/callable/argument/ReadArgumentCheckNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/callable/argument/ReadArgumentCheckNode.java deleted file mode 100644 index b353b105e5..0000000000 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/callable/argument/ReadArgumentCheckNode.java +++ /dev/null @@ -1,578 +0,0 @@ -package org.enso.interpreter.node.callable.argument; - -import com.oracle.truffle.api.CompilerAsserts; -import com.oracle.truffle.api.CompilerDirectives; -import com.oracle.truffle.api.TruffleLanguage; -import com.oracle.truffle.api.dsl.Cached; -import com.oracle.truffle.api.dsl.Cached.Shared; -import com.oracle.truffle.api.dsl.Specialization; -import com.oracle.truffle.api.frame.MaterializedFrame; -import com.oracle.truffle.api.frame.VirtualFrame; -import com.oracle.truffle.api.interop.InteropLibrary; -import com.oracle.truffle.api.interop.UnsupportedMessageException; -import com.oracle.truffle.api.nodes.ExplodeLoop; -import com.oracle.truffle.api.nodes.InvalidAssumptionException; -import com.oracle.truffle.api.nodes.Node; -import com.oracle.truffle.api.nodes.NodeUtil; -import com.oracle.truffle.api.nodes.RootNode; -import java.util.Arrays; -import java.util.List; -import java.util.function.Supplier; -import java.util.stream.Collectors; -import java.util.stream.Stream; -import org.enso.interpreter.EnsoLanguage; -import org.enso.interpreter.node.BaseNode.TailStatus; -import org.enso.interpreter.node.EnsoRootNode; -import org.enso.interpreter.node.ExpressionNode; -import org.enso.interpreter.node.callable.ApplicationNode; -import org.enso.interpreter.node.callable.InvokeCallableNode.DefaultsExecutionMode; -import org.enso.interpreter.node.callable.thunk.ThunkExecutorNode; -import org.enso.interpreter.node.expression.builtin.meta.AtomWithAHoleNode; -import org.enso.interpreter.node.expression.builtin.meta.IsValueOfTypeNode; -import org.enso.interpreter.node.expression.literal.LiteralNode; -import org.enso.interpreter.runtime.EnsoContext; -import org.enso.interpreter.runtime.callable.UnresolvedConstructor; -import org.enso.interpreter.runtime.callable.UnresolvedConversion; -import org.enso.interpreter.runtime.callable.argument.ArgumentDefinition; -import org.enso.interpreter.runtime.callable.argument.ArgumentDefinition.ExecutionMode; -import org.enso.interpreter.runtime.callable.argument.CallArgument; -import org.enso.interpreter.runtime.callable.function.Function; -import org.enso.interpreter.runtime.callable.function.FunctionSchema; -import org.enso.interpreter.runtime.data.EnsoMultiValue; -import org.enso.interpreter.runtime.data.Type; -import org.enso.interpreter.runtime.data.text.Text; -import org.enso.interpreter.runtime.error.DataflowError; -import org.enso.interpreter.runtime.error.PanicException; -import org.enso.interpreter.runtime.error.PanicSentinel; -import org.enso.interpreter.runtime.library.dispatch.TypeOfNode; -import org.enso.interpreter.runtime.library.dispatch.TypesLibrary; -import org.enso.interpreter.runtime.util.CachingSupplier; -import org.graalvm.collections.Pair; - -public abstract class ReadArgumentCheckNode extends Node { - private final String comment; - @CompilerDirectives.CompilationFinal private String expectedTypeMessage; - - ReadArgumentCheckNode(String comment) { - this.comment = comment; - } - - /** */ - public static ExpressionNode wrap(ExpressionNode original, ReadArgumentCheckNode check) { - return new TypeCheckExpressionNode(original, check); - } - - /** - * Executes check or conversion of the value. - * - * @param frame frame requesting the conversion - * @param value the value to convert - * @return {@code null} when the check isn't satisfied and conversion isn't possible or non-{@code - * null} value that can be used as a result - */ - public final Object handleCheckOrConversion(VirtualFrame frame, Object value) { - var result = executeCheckOrConversion(frame, value); - if (result == null) { - throw panicAtTheEnd(value); - } - return result; - } - - abstract Object findDirectMatch(VirtualFrame frame, Object value); - - abstract Object executeCheckOrConversion(VirtualFrame frame, Object value); - - abstract String expectedTypeMessage(); - - protected final String joinTypeParts(List parts, String separator) { - assert !parts.isEmpty(); - if (parts.size() == 1) { - return parts.get(0); - } - - var separatorWithSpace = " " + separator + " "; - var builder = new StringBuilder(); - boolean isFirst = true; - for (String part : parts) { - if (isFirst) { - isFirst = false; - } else { - builder.append(separatorWithSpace); - } - - // If the part contains a space, it means it is not a single type but already a more complex - // expression with a separator. - // So to ensure we don't mess up the expression layers, we need to add parentheses around it. - boolean needsParentheses = part.contains(" "); - if (needsParentheses) { - builder.append("("); - } - builder.append(part); - if (needsParentheses) { - builder.append(")"); - } - } - - return builder.toString(); - } - - final PanicException panicAtTheEnd(Object v) { - if (expectedTypeMessage == null) { - CompilerDirectives.transferToInterpreterAndInvalidate(); - expectedTypeMessage = expectedTypeMessage(); - } - var ctx = EnsoContext.get(this); - Text msg; - if (v instanceof UnresolvedConstructor) { - msg = Text.create("Cannot find constructor {got} among {exp}"); - } else { - var where = Text.create(comment == null ? "expression" : comment); - var exp = Text.create("expected "); - var got = Text.create(" to be {exp}, but got {got}"); - msg = Text.create(exp, Text.create(where, got)); - } - var err = ctx.getBuiltins().error().makeTypeErrorOfComment(expectedTypeMessage, v, msg); - throw new PanicException(err, this); - } - - public static ReadArgumentCheckNode allOf(String argumentName, ReadArgumentCheckNode... checks) { - var list = Arrays.asList(checks); - var flatten = - list.stream() - .flatMap( - n -> n instanceof AllOfNode all ? Arrays.asList(all.checks).stream() : Stream.of(n)) - .toList(); - var arr = toArray(flatten); - return switch (arr.length) { - case 0 -> null; - case 1 -> arr[0]; - default -> new AllOfNode(argumentName, arr); - }; - } - - public static ReadArgumentCheckNode oneOf(String comment, List checks) { - var arr = toArray(checks); - return switch (arr.length) { - case 0 -> null; - case 1 -> arr[0]; - default -> new OneOfNode(comment, arr); - }; - } - - public static ReadArgumentCheckNode build(EnsoContext ctx, String comment, Type expectedType) { - assert ctx.getBuiltins().any() != expectedType : "Don't check for Any: " + expectedType; - return ReadArgumentCheckNodeFactory.TypeCheckNodeGen.create(comment, expectedType); - } - - public static ReadArgumentCheckNode meta( - String comment, Supplier metaObjectSupplier) { - var cachingSupplier = CachingSupplier.wrap(metaObjectSupplier); - return ReadArgumentCheckNodeFactory.MetaCheckNodeGen.create(comment, cachingSupplier); - } - - public static boolean isWrappedThunk(Function fn) { - if (fn.getSchema() == LazyCheckRootNode.SCHEMA) { - return fn.getPreAppliedArguments()[0] instanceof Function wrappedFn && wrappedFn.isThunk(); - } - return false; - } - - private static boolean isAllFitValue(Object v) { - return v instanceof DataflowError || AtomWithAHoleNode.isHole(v); - } - - private static ReadArgumentCheckNode[] toArray(List list) { - if (list == null) { - return new ReadArgumentCheckNode[0]; - } - var cnt = (int) list.stream().filter(n -> n != null).count(); - var arr = new ReadArgumentCheckNode[cnt]; - var it = list.iterator(); - for (int i = 0; i < cnt; ) { - var element = it.next(); - if (element != null) { - arr[i++] = element; - } - } - return arr; - } - - static final class AllOfNode extends ReadArgumentCheckNode { - @Children private ReadArgumentCheckNode[] checks; - @Child private TypesLibrary types; - - AllOfNode(String name, ReadArgumentCheckNode[] checks) { - super(name); - this.checks = checks; - this.types = TypesLibrary.getFactory().createDispatched(checks.length); - } - - @Override - Object findDirectMatch(VirtualFrame frame, Object value) { - return null; - } - - @Override - @ExplodeLoop - Object executeCheckOrConversion(VirtualFrame frame, Object value) { - var values = new Object[checks.length]; - var valueTypes = new Type[checks.length]; - var at = 0; - for (var n : checks) { - var result = n.executeCheckOrConversion(frame, value); - if (result == null) { - return null; - } - values[at] = result; - valueTypes[at] = types.getType(result); - at++; - } - return EnsoMultiValue.create(valueTypes, values); - } - - @Override - String expectedTypeMessage() { - var parts = - Arrays.stream(checks) - .map(ReadArgumentCheckNode::expectedTypeMessage) - .collect(Collectors.toList()); - return joinTypeParts(parts, "&"); - } - } - - static final class OneOfNode extends ReadArgumentCheckNode { - @Children private ReadArgumentCheckNode[] checks; - - OneOfNode(String name, ReadArgumentCheckNode[] checks) { - super(name); - this.checks = checks; - } - - @Override - @ExplodeLoop - final Object findDirectMatch(VirtualFrame frame, Object value) { - for (var n : checks) { - var result = n.findDirectMatch(frame, value); - if (result != null) { - return result; - } - } - return null; - } - - @Override - @ExplodeLoop - Object executeCheckOrConversion(VirtualFrame frame, Object value) { - var direct = findDirectMatch(frame, value); - if (direct != null) { - return direct; - } - for (var n : checks) { - var result = n.executeCheckOrConversion(frame, value); - if (result != null) { - return result; - } - } - return null; - } - - @Override - String expectedTypeMessage() { - var parts = - Arrays.stream(checks) - .map(ReadArgumentCheckNode::expectedTypeMessage) - .collect(Collectors.toList()); - return joinTypeParts(parts, "|"); - } - } - - abstract static class TypeCheckNode extends ReadArgumentCheckNode { - private final Type expectedType; - @Child IsValueOfTypeNode checkType; - @CompilerDirectives.CompilationFinal private String expectedTypeMessage; - @CompilerDirectives.CompilationFinal private LazyCheckRootNode lazyCheck; - @Child private EnsoMultiValue.CastToNode castTo; - - TypeCheckNode(String name, Type expectedType) { - super(name); - this.checkType = IsValueOfTypeNode.build(); - this.expectedType = expectedType; - } - - @Specialization - Object doPanicSentinel(VirtualFrame frame, PanicSentinel panicSentinel) { - throw panicSentinel; - } - - @Specialization - Object doUnresolvedConstructor( - VirtualFrame frame, - UnresolvedConstructor unresolved, - @Cached UnresolvedConstructor.ConstructNode construct) { - var state = Function.ArgumentsHelper.getState(frame.getArguments()); - return construct.execute(frame, state, expectedType, unresolved); - } - - @Specialization(rewriteOn = InvalidAssumptionException.class) - Object doCheckNoConversionNeeded(VirtualFrame frame, Object v) - throws InvalidAssumptionException { - var ret = findDirectMatch(frame, v); - if (ret != null) { - return ret; - } else { - throw new InvalidAssumptionException(); - } - } - - @Specialization( - limit = "10", - guards = {"cachedType != null", "findType(typeOfNode, v, cachedType) == cachedType"}) - Object doWithConversionCached( - VirtualFrame frame, - Object v, - @Shared("typeOfNode") @Cached TypeOfNode typeOfNode, - @Cached(value = "findType(typeOfNode, v)", dimensions = 1) Type[] cachedType, - @Cached("findConversionNode(cachedType)") ApplicationNode convertNode) { - return handleWithConversion(frame, v, convertNode); - } - - @Specialization(replaces = "doWithConversionCached") - Object doWithConversionUncached( - VirtualFrame frame, Object v, @Shared("typeOfNode") @Cached TypeOfNode typeOfNode) { - var type = findType(typeOfNode, v); - return doWithConversionUncachedBoundary(frame == null ? null : frame.materialize(), v, type); - } - - @ExplodeLoop - final Object findDirectMatch(VirtualFrame frame, Object v) { - if (isAllFitValue(v)) { - return v; - } - if (v instanceof Function fn && fn.isThunk()) { - if (lazyCheck == null) { - CompilerDirectives.transferToInterpreter(); - var enso = EnsoLanguage.get(this); - var node = (ReadArgumentCheckNode) copy(); - lazyCheck = new LazyCheckRootNode(enso, node); - } - var lazyCheckFn = lazyCheck.wrapThunk(fn); - return lazyCheckFn; - } - if (v instanceof EnsoMultiValue mv) { - if (castTo == null) { - CompilerDirectives.transferToInterpreter(); - castTo = insert(EnsoMultiValue.CastToNode.create()); - } - var result = castTo.executeCast(expectedType, mv); - if (result != null) { - return result; - } - } - if (checkType.execute(expectedType, v)) { - return v; - } - return null; - } - - private Pair findConversion(Type from) { - if (expectedType == from) { - return null; - } - var ctx = EnsoContext.get(this); - - if (getRootNode() instanceof EnsoRootNode root) { - var convert = UnresolvedConversion.build(root.getModuleScope()); - var conv = convert.resolveFor(ctx, expectedType, from); - if (conv != null) { - return Pair.create(conv, expectedType); - } - } - return null; - } - - ApplicationNode findConversionNode(Type[] allTypes) { - if (allTypes == null) { - allTypes = new Type[] {null}; - } - for (var from : allTypes) { - var convAndType = findConversion(from); - - if (convAndType != null) { - if (NodeUtil.findParent(this, ReadArgumentNode.class) instanceof ReadArgumentNode ran) { - CompilerAsserts.neverPartOfCompilation(); - var convNode = LiteralNode.build(convAndType.getLeft()); - var intoNode = LiteralNode.build(convAndType.getRight()); - var valueNode = ran.plainRead(); - var args = - new CallArgument[] { - new CallArgument(null, intoNode), new CallArgument(null, valueNode) - }; - return ApplicationNode.build(convNode, args, DefaultsExecutionMode.EXECUTE); - } else if (NodeUtil.findParent(this, TypeCheckExpressionNode.class) - instanceof TypeCheckExpressionNode tcen) { - CompilerAsserts.neverPartOfCompilation(); - var convNode = LiteralNode.build(convAndType.getLeft()); - var intoNode = LiteralNode.build(convAndType.getRight()); - var valueNode = tcen.original; - var args = - new CallArgument[] { - new CallArgument(null, intoNode), new CallArgument(null, valueNode) - }; - return ApplicationNode.build(convNode, args, DefaultsExecutionMode.EXECUTE); - } - } - } - return null; - } - - Type[] findType(TypeOfNode typeOfNode, Object v) { - return findType(typeOfNode, v, null); - } - - Type[] findType(TypeOfNode typeOfNode, Object v, Type[] previous) { - if (v instanceof EnsoMultiValue multi) { - return multi.allTypes(); - } - if (v instanceof UnresolvedConstructor) { - return null; - } - if (typeOfNode.execute(v) instanceof Type from) { - if (previous != null && previous.length == 1 && previous[0] == from) { - return previous; - } else { - return new Type[] {from}; - } - } - return null; - } - - private Object handleWithConversion(VirtualFrame frame, Object v, ApplicationNode convertNode) - throws PanicException { - if (convertNode == null) { - var ret = findDirectMatch(frame, v); - if (ret != null) { - return ret; - } - return null; - } else { - var converted = convertNode.executeGeneric(frame); - return converted; - } - } - - @CompilerDirectives.TruffleBoundary - private Object doWithConversionUncachedBoundary( - MaterializedFrame frame, Object v, Type[] type) { - var convertNode = findConversionNode(type); - return handleWithConversion(frame, v, convertNode); - } - - @Override - String expectedTypeMessage() { - if (expectedTypeMessage != null) { - return expectedTypeMessage; - } - CompilerDirectives.transferToInterpreterAndInvalidate(); - expectedTypeMessage = expectedType.toString(); - return expectedTypeMessage; - } - } - - abstract static class MetaCheckNode extends ReadArgumentCheckNode { - private final CachingSupplier expectedSupplier; - @CompilerDirectives.CompilationFinal private String expectedTypeMessage; - - MetaCheckNode(String name, CachingSupplier expectedMetaSupplier) { - super(name); - this.expectedSupplier = expectedMetaSupplier; - } - - @Override - Object findDirectMatch(VirtualFrame frame, Object value) { - return executeCheckOrConversion(frame, value); - } - - @Specialization() - Object verifyMetaObject(VirtualFrame frame, Object v, @Cached IsValueOfTypeNode isA) { - if (isAllFitValue(v)) { - return v; - } - if (isA.execute(expectedSupplier.get(), v)) { - return v; - } else { - return null; - } - } - - @Override - String expectedTypeMessage() { - if (expectedTypeMessage != null) { - return expectedTypeMessage; - } - CompilerDirectives.transferToInterpreterAndInvalidate(); - var iop = InteropLibrary.getUncached(); - try { - expectedTypeMessage = iop.asString(iop.getMetaQualifiedName(expectedSupplier.get())); - } catch (UnsupportedMessageException ex) { - expectedTypeMessage = expectedSupplier.get().toString(); - } - return expectedTypeMessage; - } - } - - private static final class LazyCheckRootNode extends RootNode { - - @Child private ThunkExecutorNode evalThunk; - @Child private ReadArgumentCheckNode check; - - static final FunctionSchema SCHEMA = - FunctionSchema.newBuilder() - .argumentDefinitions( - new ArgumentDefinition(0, "delegate", null, null, ExecutionMode.EXECUTE)) - .hasPreapplied(true) - .build(); - - LazyCheckRootNode(TruffleLanguage language, ReadArgumentCheckNode check) { - super(language); - this.check = check; - this.evalThunk = ThunkExecutorNode.build(); - } - - Function wrapThunk(Function thunk) { - return new Function(getCallTarget(), thunk.getScope(), SCHEMA, new Object[] {thunk}, null); - } - - @Override - public Object execute(VirtualFrame frame) { - var state = Function.ArgumentsHelper.getState(frame.getArguments()); - var args = Function.ArgumentsHelper.getPositionalArguments(frame.getArguments()); - assert args.length == 1; - assert args[0] instanceof Function fn && fn.isThunk(); - var raw = evalThunk.executeThunk(frame, args[0], state, TailStatus.NOT_TAIL); - var result = check.handleCheckOrConversion(frame, raw); - return result; - } - } - - private static final class TypeCheckExpressionNode extends ExpressionNode { - @Child private ExpressionNode original; - @Child private ReadArgumentCheckNode check; - - TypeCheckExpressionNode(ExpressionNode original, ReadArgumentCheckNode check) { - this.check = check; - this.original = original; - } - - @Override - public Object executeGeneric(VirtualFrame frame) { - var value = original.executeGeneric(frame); - var result = check.handleCheckOrConversion(frame, value); - return result; - } - - @Override - public boolean isInstrumentable() { - return false; - } - } -} diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/callable/argument/ReadArgumentNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/callable/argument/ReadArgumentNode.java index c2b7c122e6..f3d194ff70 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/callable/argument/ReadArgumentNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/callable/argument/ReadArgumentNode.java @@ -14,13 +14,11 @@ import org.enso.interpreter.runtime.callable.function.Function; public final class ReadArgumentNode extends ExpressionNode { private final int index; @Child ExpressionNode defaultValue; - @Child ReadArgumentCheckNode checkType; private final CountingConditionProfile defaultingProfile = CountingConditionProfile.create(); - private ReadArgumentNode(int position, ExpressionNode defaultValue, ReadArgumentCheckNode check) { + private ReadArgumentNode(int position, ExpressionNode defaultValue) { this.index = position; this.defaultValue = defaultValue; - this.checkType = check; } /** @@ -28,18 +26,10 @@ public final class ReadArgumentNode extends ExpressionNode { * * @param position the argument's position at the definition site * @param defaultValue the default value provided for that argument - * @param check {@code null} or node to check type of input * @return a node representing the argument at position {@code idx} */ - public static ReadArgumentNode build( - int position, ExpressionNode defaultValue, ReadArgumentCheckNode check) { - return new ReadArgumentNode(position, defaultValue, check); - } - - ReadArgumentNode plainRead() { - var node = (ReadArgumentNode) this.copy(); - node.checkType = null; - return node; + public static ReadArgumentNode build(int position, ExpressionNode defaultValue) { + return new ReadArgumentNode(position, defaultValue); } /** @@ -68,9 +58,6 @@ public final class ReadArgumentNode extends ExpressionNode { v = arguments[index]; } } - if (checkType != null) { - v = checkType.handleCheckOrConversion(frame, v); - } return v; } diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/typecheck/AbstractTypeCheckNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/typecheck/AbstractTypeCheckNode.java new file mode 100644 index 0000000000..2efb2602a1 --- /dev/null +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/typecheck/AbstractTypeCheckNode.java @@ -0,0 +1,93 @@ +package org.enso.interpreter.node.typecheck; + +import com.oracle.truffle.api.CompilerDirectives; +import com.oracle.truffle.api.frame.VirtualFrame; +import com.oracle.truffle.api.nodes.Node; +import java.util.List; +import org.enso.interpreter.node.ExpressionNode; +import org.enso.interpreter.node.expression.builtin.meta.AtomWithAHoleNode; +import org.enso.interpreter.runtime.data.text.Text; +import org.enso.interpreter.runtime.error.DataflowError; + +/** + * Root of hierarchy of nodes checking types. This class (and its subclasses) are an implementation + * detail. The API to perform the is in {@link TypeCheckNode}. + */ +abstract sealed class AbstractTypeCheckNode extends Node + permits OneOfTypesCheckNode, AllOfTypesCheckNode, SingleTypeCheckNode, MetaTypeCheckNode { + private final String comment; + @CompilerDirectives.CompilationFinal private String expectedTypeMessage; + + AbstractTypeCheckNode(String comment) { + this.comment = comment; + } + + abstract Object findDirectMatch(VirtualFrame frame, Object value); + + abstract Object executeCheckOrConversion( + VirtualFrame frame, Object value, ExpressionNode valueNode); + + abstract String expectedTypeMessage(); + + /** + * The error message for this node's check. Ready for being used at "fast path". + * + * @return + */ + final String getExpectedTypeMessage() { + if (expectedTypeMessage == null) { + CompilerDirectives.transferToInterpreterAndInvalidate(); + expectedTypeMessage = expectedTypeMessage(); + } + return expectedTypeMessage; + } + + /** + * Composes the comment message describing the kind of check this node performs. Ready for being + * used at "fast path". + * + * @return description of this node's "expectations" + */ + final Text getComment() { + var where = Text.create(comment == null ? "expression" : comment); + var exp = Text.create("expected "); + var got = Text.create(" to be {exp}, but got {got}"); + return Text.create(exp, Text.create(where, got)); + } + + final String joinTypeParts(List parts, String separator) { + assert !parts.isEmpty(); + if (parts.size() == 1) { + return parts.get(0); + } + + var separatorWithSpace = " " + separator + " "; + var builder = new StringBuilder(); + boolean isFirst = true; + for (String part : parts) { + if (isFirst) { + isFirst = false; + } else { + builder.append(separatorWithSpace); + } + + // If the part contains a space, it means it is not a single type but already a more complex + // expression with a separator. + // So to ensure we don't mess up the expression layers, we need to add parentheses around it. + boolean needsParentheses = part.contains(" "); + if (needsParentheses) { + builder.append("("); + } + builder.append(part); + if (needsParentheses) { + builder.append(")"); + } + } + + return builder.toString(); + } + + static boolean isAllFitValue(Object v) { + return v instanceof DataflowError || AtomWithAHoleNode.isHole(v); + } +} diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/typecheck/AllOfTypesCheckNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/typecheck/AllOfTypesCheckNode.java new file mode 100644 index 0000000000..f06c0f40f8 --- /dev/null +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/typecheck/AllOfTypesCheckNode.java @@ -0,0 +1,58 @@ +package org.enso.interpreter.node.typecheck; + +import com.oracle.truffle.api.frame.VirtualFrame; +import com.oracle.truffle.api.nodes.ExplodeLoop; +import java.util.Arrays; +import java.util.stream.Collectors; +import org.enso.interpreter.node.ExpressionNode; +import org.enso.interpreter.runtime.data.EnsoMultiValue; +import org.enso.interpreter.runtime.data.Type; +import org.enso.interpreter.runtime.library.dispatch.TypesLibrary; + +final class AllOfTypesCheckNode extends AbstractTypeCheckNode { + + @Children private AbstractTypeCheckNode[] checks; + @Child private TypesLibrary types; + + AllOfTypesCheckNode(String name, AbstractTypeCheckNode[] checks) { + super(name); + this.checks = checks; + this.types = TypesLibrary.getFactory().createDispatched(checks.length); + } + + AbstractTypeCheckNode[] getChecks() { + return checks; + } + + @Override + Object findDirectMatch(VirtualFrame frame, Object value) { + return null; + } + + @Override + @ExplodeLoop + Object executeCheckOrConversion(VirtualFrame frame, Object value, ExpressionNode expr) { + var values = new Object[checks.length]; + var valueTypes = new Type[checks.length]; + var at = 0; + for (var n : checks) { + var result = n.executeCheckOrConversion(frame, value, expr); + if (result == null) { + return null; + } + values[at] = result; + valueTypes[at] = types.getType(result); + at++; + } + return EnsoMultiValue.create(valueTypes, values); + } + + @Override + String expectedTypeMessage() { + var parts = + Arrays.stream(checks) + .map(AbstractTypeCheckNode::expectedTypeMessage) + .collect(Collectors.toList()); + return joinTypeParts(parts, "&"); + } +} diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/typecheck/LazyCheckRootNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/typecheck/LazyCheckRootNode.java new file mode 100644 index 0000000000..3a397bb0fc --- /dev/null +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/typecheck/LazyCheckRootNode.java @@ -0,0 +1,44 @@ +package org.enso.interpreter.node.typecheck; + +import com.oracle.truffle.api.TruffleLanguage; +import com.oracle.truffle.api.frame.VirtualFrame; +import com.oracle.truffle.api.nodes.RootNode; +import org.enso.interpreter.node.BaseNode; +import org.enso.interpreter.node.callable.thunk.ThunkExecutorNode; +import org.enso.interpreter.runtime.callable.argument.ArgumentDefinition; +import org.enso.interpreter.runtime.callable.function.Function; +import org.enso.interpreter.runtime.callable.function.FunctionSchema; + +final class LazyCheckRootNode extends RootNode { + + @Child private ThunkExecutorNode evalThunk; + @Child private TypeCheckValueNode check; + static final FunctionSchema SCHEMA = + FunctionSchema.newBuilder() + .argumentDefinitions( + new ArgumentDefinition( + 0, "delegate", null, null, ArgumentDefinition.ExecutionMode.EXECUTE)) + .hasPreapplied(true) + .build(); + + LazyCheckRootNode(TruffleLanguage language, TypeCheckValueNode check) { + super(language); + this.check = check; + this.evalThunk = ThunkExecutorNode.build(); + } + + Function wrapThunk(Function thunk) { + return new Function(getCallTarget(), thunk.getScope(), SCHEMA, new Object[] {thunk}, null); + } + + @Override + public Object execute(VirtualFrame frame) { + var state = Function.ArgumentsHelper.getState(frame.getArguments()); + var args = Function.ArgumentsHelper.getPositionalArguments(frame.getArguments()); + assert args.length == 1; + assert args[0] instanceof Function fn && fn.isThunk(); + var raw = evalThunk.executeThunk(frame, args[0], state, BaseNode.TailStatus.NOT_TAIL); + var result = check.handleCheckOrConversion(frame, raw, null); + return result; + } +} diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/typecheck/MetaTypeCheckNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/typecheck/MetaTypeCheckNode.java new file mode 100644 index 0000000000..00ac749966 --- /dev/null +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/typecheck/MetaTypeCheckNode.java @@ -0,0 +1,58 @@ +package org.enso.interpreter.node.typecheck; + +import com.oracle.truffle.api.CompilerDirectives; +import com.oracle.truffle.api.dsl.Cached; +import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.frame.VirtualFrame; +import com.oracle.truffle.api.interop.InteropLibrary; +import com.oracle.truffle.api.interop.UnsupportedMessageException; +import org.enso.interpreter.node.expression.builtin.meta.IsValueOfTypeNode; +import org.enso.interpreter.runtime.util.CachingSupplier; + +/** + * Node for checking {@code polyglot java import} types. This class (and its subclasses) + * are an implementation detail. The API to perform the is in {@link TypeCheckNode}. + */ +non-sealed abstract class MetaTypeCheckNode extends AbstractTypeCheckNode { + private final CachingSupplier expectedSupplier; + @CompilerDirectives.CompilationFinal private String expectedTypeMessage; + + MetaTypeCheckNode(String name, CachingSupplier expectedMetaSupplier) { + super(name); + this.expectedSupplier = expectedMetaSupplier; + } + + abstract Object executeCheckOrConversion(VirtualFrame frame, Object value); + + @Override + Object findDirectMatch(VirtualFrame frame, Object value) { + return executeCheckOrConversion(frame, value); + } + + @Specialization + Object verifyMetaObject(VirtualFrame frame, Object v, @Cached IsValueOfTypeNode isA) { + if (isAllFitValue(v)) { + return v; + } + if (isA.execute(expectedSupplier.get(), v)) { + return v; + } else { + return null; + } + } + + @Override + String expectedTypeMessage() { + if (expectedTypeMessage != null) { + return expectedTypeMessage; + } + CompilerDirectives.transferToInterpreterAndInvalidate(); + com.oracle.truffle.api.interop.InteropLibrary iop = InteropLibrary.getUncached(); + try { + expectedTypeMessage = iop.asString(iop.getMetaQualifiedName(expectedSupplier.get())); + } catch (UnsupportedMessageException ex) { + expectedTypeMessage = expectedSupplier.get().toString(); + } + return expectedTypeMessage; + } +} diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/typecheck/OneOfTypesCheckNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/typecheck/OneOfTypesCheckNode.java new file mode 100644 index 0000000000..6cbe1850c0 --- /dev/null +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/typecheck/OneOfTypesCheckNode.java @@ -0,0 +1,54 @@ +package org.enso.interpreter.node.typecheck; + +import com.oracle.truffle.api.frame.VirtualFrame; +import com.oracle.truffle.api.nodes.ExplodeLoop; +import java.util.Arrays; +import java.util.stream.Collectors; +import org.enso.interpreter.node.ExpressionNode; + +final class OneOfTypesCheckNode extends AbstractTypeCheckNode { + + @Children private AbstractTypeCheckNode[] checks; + + OneOfTypesCheckNode(String name, AbstractTypeCheckNode[] checks) { + super(name); + this.checks = checks; + } + + @Override + @ExplodeLoop + final Object findDirectMatch(VirtualFrame frame, Object value) { + for (var n : checks) { + java.lang.Object result = n.findDirectMatch(frame, value); + if (result != null) { + return result; + } + } + return null; + } + + @Override + @ExplodeLoop + Object executeCheckOrConversion(VirtualFrame frame, Object value, ExpressionNode expr) { + java.lang.Object direct = findDirectMatch(frame, value); + if (direct != null) { + return direct; + } + for (var n : checks) { + java.lang.Object result = n.executeCheckOrConversion(frame, value, expr); + if (result != null) { + return result; + } + } + return null; + } + + @Override + String expectedTypeMessage() { + java.util.List parts = + Arrays.stream(checks) + .map(AbstractTypeCheckNode::expectedTypeMessage) + .collect(Collectors.toList()); + return joinTypeParts(parts, "|"); + } +} diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/typecheck/SingleTypeCheckNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/typecheck/SingleTypeCheckNode.java new file mode 100644 index 0000000000..0996660009 --- /dev/null +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/typecheck/SingleTypeCheckNode.java @@ -0,0 +1,221 @@ +package org.enso.interpreter.node.typecheck; + +import com.oracle.truffle.api.CompilerAsserts; +import com.oracle.truffle.api.CompilerDirectives; +import com.oracle.truffle.api.dsl.Cached; +import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.frame.MaterializedFrame; +import com.oracle.truffle.api.frame.VirtualFrame; +import com.oracle.truffle.api.nodes.ExplodeLoop; +import com.oracle.truffle.api.nodes.InvalidAssumptionException; +import com.oracle.truffle.api.nodes.Node; +import org.enso.interpreter.EnsoLanguage; +import org.enso.interpreter.node.EnsoRootNode; +import org.enso.interpreter.node.ExpressionNode; +import org.enso.interpreter.node.callable.ApplicationNode; +import org.enso.interpreter.node.callable.InvokeCallableNode; +import org.enso.interpreter.node.expression.builtin.meta.IsValueOfTypeNode; +import org.enso.interpreter.node.expression.literal.LiteralNode; +import org.enso.interpreter.runtime.EnsoContext; +import org.enso.interpreter.runtime.callable.UnresolvedConstructor; +import org.enso.interpreter.runtime.callable.UnresolvedConversion; +import org.enso.interpreter.runtime.callable.argument.CallArgument; +import org.enso.interpreter.runtime.callable.function.Function; +import org.enso.interpreter.runtime.data.EnsoMultiValue; +import org.enso.interpreter.runtime.data.Type; +import org.enso.interpreter.runtime.error.PanicException; +import org.enso.interpreter.runtime.error.PanicSentinel; +import org.enso.interpreter.runtime.library.dispatch.TypeOfNode; +import org.graalvm.collections.Pair; + +non-sealed abstract class SingleTypeCheckNode extends AbstractTypeCheckNode { + private final Type expectedType; + @Node.Child IsValueOfTypeNode checkType; + @CompilerDirectives.CompilationFinal private String expectedTypeMessage; + @CompilerDirectives.CompilationFinal private LazyCheckRootNode lazyCheck; + @Node.Child private EnsoMultiValue.CastToNode castTo; + + SingleTypeCheckNode(String name, Type expectedType) { + super(name); + this.checkType = IsValueOfTypeNode.build(); + this.expectedType = expectedType; + } + + abstract Object executeCheckOrConversion( + VirtualFrame frame, Object value, ExpressionNode valueSource); + + @Specialization + Object doPanicSentinel(VirtualFrame frame, PanicSentinel panicSentinel, ExpressionNode ignore) { + throw panicSentinel; + } + + @Specialization + Object doUnresolvedConstructor( + VirtualFrame frame, + UnresolvedConstructor unresolved, + ExpressionNode ignore, + @Cached UnresolvedConstructor.ConstructNode construct) { + var state = Function.ArgumentsHelper.getState(frame.getArguments()); + return construct.execute(frame, state, expectedType, unresolved); + } + + @Specialization(rewriteOn = InvalidAssumptionException.class) + Object doCheckNoConversionNeeded(VirtualFrame frame, Object v, ExpressionNode ignore) + throws InvalidAssumptionException { + var ret = findDirectMatch(frame, v); + if (ret != null) { + return ret; + } else { + throw new InvalidAssumptionException(); + } + } + + @Specialization( + limit = "10", + guards = {"cachedType != null", "findType(typeOfNode, v, cachedType) == cachedType"}) + Object doWithConversionCached( + VirtualFrame frame, + Object v, + ExpressionNode valueSource, + @Cached.Shared("typeOfNode") @Cached TypeOfNode typeOfNode, + @Cached(value = "findType(typeOfNode, v)", dimensions = 1) Type[] cachedType, + @Cached("findConversionNode(valueSource, cachedType)") ApplicationNode convertNode) { + return handleWithConversion(frame, v, convertNode); + } + + @Specialization(replaces = "doWithConversionCached") + Object doWithConversionUncached( + VirtualFrame frame, + Object v, + ExpressionNode expr, + @Cached.Shared("typeOfNode") @Cached TypeOfNode typeOfNode) { + var type = findType(typeOfNode, v); + return doWithConversionUncachedBoundary( + frame == null ? null : frame.materialize(), v, expr, type); + } + + @ExplodeLoop + final Object findDirectMatch(VirtualFrame frame, Object v) { + if (isAllFitValue(v)) { + return v; + } + if (v instanceof Function fn && fn.isThunk()) { + if (lazyCheck == null) { + CompilerDirectives.transferToInterpreter(); + var enso = EnsoLanguage.get(this); + var node = (AbstractTypeCheckNode) copy(); + lazyCheck = new LazyCheckRootNode(enso, new TypeCheckValueNode(node)); + } + var lazyCheckFn = lazyCheck.wrapThunk(fn); + return lazyCheckFn; + } + assert EnsoContext.get(this).getBuiltins().any() != expectedType : "Don't check for Any: " + expectedType; + if (v instanceof EnsoMultiValue mv) { + if (castTo == null) { + CompilerDirectives.transferToInterpreter(); + castTo = insert(EnsoMultiValue.CastToNode.create()); + } + var result = castTo.executeCast(expectedType, mv); + if (result != null) { + return result; + } + } + if (checkType.execute(expectedType, v)) { + return v; + } + return null; + } + + private Pair findConversion(Type from) { + if (expectedType == from) { + return null; + } + var ctx = EnsoContext.get(this); + + if (getRootNode() instanceof EnsoRootNode root) { + var convert = UnresolvedConversion.build(root.getModuleScope()); + var conv = convert.resolveFor(ctx, expectedType, from); + if (conv != null) { + return Pair.create(conv, expectedType); + } + } + return null; + } + + ApplicationNode findConversionNode(ExpressionNode valueNode, Type[] allTypes) { + if (valueNode == null) { + return null; + } + if (allTypes == null) { + allTypes = new Type[] {null}; + } + for (var from : allTypes) { + var convAndType = findConversion(from); + + if (convAndType != null) { + CompilerAsserts.neverPartOfCompilation(); + var convNode = LiteralNode.build(convAndType.getLeft()); + var intoNode = LiteralNode.build(convAndType.getRight()); + var args = + new CallArgument[] { + new CallArgument(null, intoNode), new CallArgument(null, valueNode) + }; + return ApplicationNode.build( + convNode, args, InvokeCallableNode.DefaultsExecutionMode.EXECUTE); + } + } + return null; + } + + Type[] findType(TypeOfNode typeOfNode, Object v) { + return findType(typeOfNode, v, null); + } + + Type[] findType(TypeOfNode typeOfNode, Object v, Type[] previous) { + if (v instanceof EnsoMultiValue multi) { + return multi.allTypes(); + } + if (v instanceof UnresolvedConstructor) { + return null; + } + if (typeOfNode.execute(v) instanceof Type from) { + if (previous != null && previous.length == 1 && previous[0] == from) { + return previous; + } else { + return new Type[] {from}; + } + } + return null; + } + + private Object handleWithConversion(VirtualFrame frame, Object v, ApplicationNode convertNode) + throws PanicException { + if (convertNode == null) { + var ret = findDirectMatch(frame, v); + if (ret != null) { + return ret; + } + return null; + } else { + var converted = convertNode.executeGeneric(frame); + return converted; + } + } + + @CompilerDirectives.TruffleBoundary + private Object doWithConversionUncachedBoundary( + MaterializedFrame frame, Object v, ExpressionNode expr, Type[] type) { + var convertNode = findConversionNode(expr, type); + return handleWithConversion(frame, v, convertNode); + } + + @Override + String expectedTypeMessage() { + if (expectedTypeMessage != null) { + return expectedTypeMessage; + } + CompilerDirectives.transferToInterpreterAndInvalidate(); + expectedTypeMessage = expectedType.toString(); + return expectedTypeMessage; + } +} diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/typecheck/TypeCheckExpressionNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/typecheck/TypeCheckExpressionNode.java new file mode 100644 index 0000000000..cfc10f5827 --- /dev/null +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/typecheck/TypeCheckExpressionNode.java @@ -0,0 +1,31 @@ +package org.enso.interpreter.node.typecheck; + +import com.oracle.truffle.api.frame.VirtualFrame; +import org.enso.interpreter.node.ExpressionNode; + +final class TypeCheckExpressionNode extends ExpressionNode { + + @Child private ExpressionNode original; + @Child private TypeCheckValueNode check; + + TypeCheckExpressionNode(ExpressionNode original, TypeCheckValueNode check) { + this.check = check; + this.original = original; + } + + ExpressionNode getOriginal() { + return original; + } + + @Override + public Object executeGeneric(VirtualFrame frame) { + java.lang.Object value = original.executeGeneric(frame); + java.lang.Object result = check.handleCheckOrConversion(frame, value, original); + return result; + } + + @Override + public boolean isInstrumentable() { + return false; + } +} diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/typecheck/TypeCheckValueNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/typecheck/TypeCheckValueNode.java new file mode 100644 index 0000000000..d92fc1f97d --- /dev/null +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/typecheck/TypeCheckValueNode.java @@ -0,0 +1,182 @@ +package org.enso.interpreter.node.typecheck; + +import com.oracle.truffle.api.frame.VirtualFrame; +import com.oracle.truffle.api.nodes.Node; +import java.util.Arrays; +import java.util.List; +import java.util.function.Supplier; +import java.util.stream.Stream; +import org.enso.interpreter.node.ExpressionNode; +import org.enso.interpreter.runtime.EnsoContext; +import org.enso.interpreter.runtime.callable.UnresolvedConstructor; +import org.enso.interpreter.runtime.callable.function.Function; +import org.enso.interpreter.runtime.data.Type; +import org.enso.interpreter.runtime.data.text.Text; +import org.enso.interpreter.runtime.error.PanicException; +import org.enso.interpreter.runtime.util.CachingSupplier; + +/** A node and a factory for nodes performing type checks (including necessary conversions). */ +public final class TypeCheckValueNode extends Node { + private @Child AbstractTypeCheckNode check; + + TypeCheckValueNode(AbstractTypeCheckNode check) { + assert check != null; + this.check = check; + } + + /** + * Wraps expression node with additional type check. + * + * @param original the expression node + * @param check node performing type check or {@code null} + * @return wrapper around {@code original} or directly {@code original} if there is {@code null} + * check + */ + public static ExpressionNode wrap(ExpressionNode original, TypeCheckValueNode check) { + if (check == null) { + return original; + } else { + return new TypeCheckExpressionNode(original, check); + } + } + + /** + * Executes check or conversion of the value. + * + * @param frame frame requesting the conversion + * @param value the value to convert + * @param expr the expression node that produced the {@code value} + * @return {@code null} when the check isn't satisfied and conversion isn't possible or non-{@code + * null} value that can be used as a result + */ + public final Object handleCheckOrConversion( + VirtualFrame frame, Object value, ExpressionNode expr) { + var result = check.executeCheckOrConversion(frame, value, expr); + if (result == null) { + throw panicAtTheEnd(value); + } + return result; + } + + /** + * Combines existing type checks into "all of" check. + * + * @param comment description of the check meaning + * @param checks existing type checks + * @return node the composed check or {@code null} if no check is needed + */ + public static TypeCheckValueNode allOf(String comment, TypeCheckValueNode... checks) { + if (checks == null) { + return null; + } + var list = Arrays.asList(checks); + var flatten = + list.stream() + .filter(n -> n != null) + .map(n -> n.check) + .flatMap( + n -> + n instanceof AllOfTypesCheckNode all + ? Arrays.asList(all.getChecks()).stream() + : Stream.of(n)) + .toList(); + var arr = toArray(flatten); + return switch (arr.length) { + case 0 -> null; + case 1 -> new TypeCheckValueNode(arr[0]); + default -> new TypeCheckValueNode(new AllOfTypesCheckNode(comment, arr)); + }; + } + + /** + * Combines existing checks into "one of" check. + * + * @param comment description of the check meaning + * @param checks existing type checks + * @return node the composed check or {@code null} if no check is needed + */ + public static TypeCheckValueNode oneOf(String comment, TypeCheckValueNode... checks) { + if (checks == null) { + return null; + } + var list = Stream.of(checks).filter(n -> n != null).toList(); + return switch (list.size()) { + case 0 -> null; + case 1 -> list.get(0); + default -> { + var abstractTypeCheckList = list.stream().map(n -> n.check).toList(); + var abstractTypeCheckArr = toArray(abstractTypeCheckList); + yield new TypeCheckValueNode(new OneOfTypesCheckNode(comment, abstractTypeCheckArr)); + } + }; + } + + /** + * Constructs "single type" check. + * + * @param comment description of the check meaning + * @param expectedType the type to check for - it shouldn't be {@code Any} + * @return node performing the check + */ + public static TypeCheckValueNode single(String comment, Type expectedType) { + var typeCheckNodeImpl = SingleTypeCheckNodeGen.create(comment, expectedType); + return new TypeCheckValueNode(typeCheckNodeImpl); + } + + /** + * Constructs node to check for {@code polyglot java import} checks. + * + * @param comment description of the check meaning + * @param metaObjectSupplier provider of the meta object to check for + * @return node performing the check + */ + public static TypeCheckValueNode meta( + String comment, Supplier metaObjectSupplier) { + var cachingSupplier = CachingSupplier.wrap(metaObjectSupplier); + var typeCheckNodeImpl = MetaTypeCheckNodeGen.create(comment, cachingSupplier); + return new TypeCheckValueNode(typeCheckNodeImpl); + } + + /** + * Check whether given function is "lazy thunk". E.g. if it is a thunk wrapped by "lazy type + * check". + * + * @param fn function to check + * @return result of the check + */ + public static boolean isWrappedThunk(Function fn) { + if (fn.getSchema() == LazyCheckRootNode.SCHEMA) { + return fn.getPreAppliedArguments()[0] instanceof Function wrappedFn && wrappedFn.isThunk(); + } + return false; + } + + private final PanicException panicAtTheEnd(Object v) { + var expectedTypeMessage = check.getExpectedTypeMessage(); + var ctx = EnsoContext.get(this); + Text msg; + if (v instanceof UnresolvedConstructor) { + msg = Text.create("Cannot find constructor {got} among {exp}"); + } else { + msg = check.getComment(); + } + var err = ctx.getBuiltins().error().makeTypeErrorOfComment(expectedTypeMessage, v, msg); + throw new PanicException(err, this); + } + + private static AbstractTypeCheckNode[] toArray(List list) { + if (list == null) { + return new AbstractTypeCheckNode[0]; + } + var cnt = (int) list.stream().filter(n -> n != null).count(); + var arr = new AbstractTypeCheckNode[cnt]; + var it = list.iterator(); + for (int i = 0; i < cnt; ) { + var element = it.next(); + if (element != null) { + arr[i++] = element; + } + } + return arr; + } +} diff --git a/engine/runtime/src/main/java/org/enso/interpreter/runtime/callable/UnresolvedConstructor.java b/engine/runtime/src/main/java/org/enso/interpreter/runtime/callable/UnresolvedConstructor.java index 1933cb9e27..e046741c0e 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/runtime/callable/UnresolvedConstructor.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/runtime/callable/UnresolvedConstructor.java @@ -188,12 +188,11 @@ public final class UnresolvedConstructor implements EnsoObject { id = callable.getId(); } } - var fn = ReadArgumentNode.build(0, null, null); + var fn = ReadArgumentNode.build(0, null); var args = new CallArgument[prototype.descs.length]; for (var i = 0; i < args.length; i++) { args[i] = - new CallArgument( - prototype.descs[i].getName(), ReadArgumentNode.build(1 + i, null, null)); + new CallArgument(prototype.descs[i].getName(), ReadArgumentNode.build(1 + i, null)); } var expr = ApplicationNode.build(fn, args, DefaultsExecutionMode.EXECUTE); if (id != null) { diff --git a/engine/runtime/src/main/java/org/enso/interpreter/runtime/callable/argument/ArgumentDefinition.java b/engine/runtime/src/main/java/org/enso/interpreter/runtime/callable/argument/ArgumentDefinition.java index b35121f767..d5c53f713e 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/runtime/callable/argument/ArgumentDefinition.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/runtime/callable/argument/ArgumentDefinition.java @@ -2,7 +2,7 @@ package org.enso.interpreter.runtime.callable.argument; import java.util.Optional; import org.enso.interpreter.node.ExpressionNode; -import org.enso.interpreter.node.callable.argument.ReadArgumentCheckNode; +import org.enso.interpreter.node.typecheck.TypeCheckValueNode; /** Tracks the specifics about how arguments are defined at the callable definition site. */ public final class ArgumentDefinition { @@ -26,7 +26,7 @@ public final class ArgumentDefinition { private final int position; private final String name; - private final ReadArgumentCheckNode checkType; + private final TypeCheckValueNode checkType; private final ExpressionNode defaultValue; private final boolean isSuspended; @@ -42,7 +42,7 @@ public final class ArgumentDefinition { public ArgumentDefinition( int position, String name, - ReadArgumentCheckNode checkType, + TypeCheckValueNode checkType, ExpressionNode defaultValue, ExecutionMode executionMode) { this.position = position; @@ -103,7 +103,7 @@ public final class ArgumentDefinition { * * @return {@code null} or list of types to check argument against */ - public ReadArgumentCheckNode getCheckType() { + public TypeCheckValueNode getCheckType() { return checkType; } } diff --git a/engine/runtime/src/main/java/org/enso/interpreter/runtime/data/atom/AtomConstructor.java b/engine/runtime/src/main/java/org/enso/interpreter/runtime/data/atom/AtomConstructor.java index 9a94a4880c..6f0a0a969c 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/runtime/data/atom/AtomConstructor.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/runtime/data/atom/AtomConstructor.java @@ -107,7 +107,7 @@ public final class AtomConstructor implements EnsoObject { EnsoLanguage language, ModuleScope.Builder scopeBuilder, ArgumentDefinition... args) { ExpressionNode[] reads = new ExpressionNode[args.length]; for (int i = 0; i < args.length; i++) { - reads[i] = ReadArgumentNode.build(i, null, null); + reads[i] = ReadArgumentNode.build(i, null); } return initializeFields( language, diff --git a/engine/runtime/src/main/java/org/enso/interpreter/runtime/data/atom/Layout.java b/engine/runtime/src/main/java/org/enso/interpreter/runtime/data/atom/Layout.java index b4cacecc9a..dfe778bff6 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/runtime/data/atom/Layout.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/runtime/data/atom/Layout.java @@ -5,7 +5,7 @@ import com.oracle.truffle.api.dsl.NodeFactory; import com.oracle.truffle.api.nodes.Node; import java.util.List; import org.enso.interpreter.dsl.atom.LayoutSpec; -import org.enso.interpreter.node.callable.argument.ReadArgumentCheckNode; +import org.enso.interpreter.node.typecheck.TypeCheckValueNode; import org.enso.interpreter.runtime.callable.argument.ArgumentDefinition; /** @@ -232,12 +232,12 @@ class Layout { private static final class SetterTypeCheckFactory implements NodeFactory { private final String argName; - private final ReadArgumentCheckNode typeCheck; + private final TypeCheckValueNode typeCheck; private final NodeFactory delegate; private SetterTypeCheckFactory( ArgumentDefinition arg, - ReadArgumentCheckNode typeCheck, + TypeCheckValueNode typeCheck, NodeFactory factory) { assert factory != null; this.argName = arg.getName(); @@ -247,7 +247,7 @@ class Layout { @Override public UnboxingAtom.FieldSetterNode createNode(Object... arguments) { - var checkNode = (ReadArgumentCheckNode) typeCheck.copy(); + var checkNode = (TypeCheckValueNode) typeCheck.copy(); var setterNode = delegate.createNode(arguments); return checkNode == null ? setterNode : new CheckFieldSetterNode(setterNode, checkNode); } @@ -274,18 +274,18 @@ class Layout { } private static final class CheckFieldSetterNode extends UnboxingAtom.FieldSetterNode { - @Child ReadArgumentCheckNode checkNode; + @Child TypeCheckValueNode checkNode; @Child UnboxingAtom.FieldSetterNode setterNode; private CheckFieldSetterNode( - UnboxingAtom.FieldSetterNode setterNode, ReadArgumentCheckNode checkNode) { + UnboxingAtom.FieldSetterNode setterNode, TypeCheckValueNode checkNode) { this.setterNode = setterNode; this.checkNode = checkNode; } @Override public void execute(Atom atom, Object value) { - var valueOrConvertedValue = checkNode.handleCheckOrConversion(null, value); + var valueOrConvertedValue = checkNode.handleCheckOrConversion(null, value, null); setterNode.execute(atom, valueOrConvertedValue); } } diff --git a/engine/runtime/src/main/java/org/enso/interpreter/runtime/data/atom/SuspendedFieldGetterNode.java b/engine/runtime/src/main/java/org/enso/interpreter/runtime/data/atom/SuspendedFieldGetterNode.java index 0d2401feab..3e068d89d9 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/runtime/data/atom/SuspendedFieldGetterNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/runtime/data/atom/SuspendedFieldGetterNode.java @@ -6,8 +6,8 @@ import com.oracle.truffle.api.nodes.Node; import com.oracle.truffle.api.profiles.BranchProfile; import java.util.List; import org.enso.interpreter.node.callable.InvokeCallableNode; -import org.enso.interpreter.node.callable.argument.ReadArgumentCheckNode; import org.enso.interpreter.node.callable.dispatch.InvokeFunctionNode; +import org.enso.interpreter.node.typecheck.TypeCheckValueNode; import org.enso.interpreter.runtime.callable.argument.CallArgumentInfo; import org.enso.interpreter.runtime.callable.function.Function; import org.enso.interpreter.runtime.data.atom.UnboxingAtom.FieldGetterNode; @@ -86,7 +86,7 @@ final class SuspendedFieldGetterNode extends UnboxingAtom.FieldGetterNode { } private static boolean shallBeExtracted(Function fn) { - return fn.isThunk() || ReadArgumentCheckNode.isWrappedThunk(fn); + return fn.isThunk() || TypeCheckValueNode.isWrappedThunk(fn); } @Override diff --git a/engine/runtime/src/main/scala/org/enso/interpreter/runtime/IrToTruffle.scala b/engine/runtime/src/main/scala/org/enso/interpreter/runtime/IrToTruffle.scala index d92326eee8..a606cbcf7d 100644 --- a/engine/runtime/src/main/scala/org/enso/interpreter/runtime/IrToTruffle.scala +++ b/engine/runtime/src/main/scala/org/enso/interpreter/runtime/IrToTruffle.scala @@ -67,7 +67,7 @@ import org.enso.compiler.pass.resolve.{ TypeSignatures } import org.enso.interpreter.node.callable.argument.ReadArgumentNode -import org.enso.interpreter.node.callable.argument.ReadArgumentCheckNode +import org.enso.interpreter.node.typecheck.TypeCheckValueNode import org.enso.interpreter.node.callable.function.{ BlockNode, CreateFunctionNode @@ -306,12 +306,12 @@ class IrToTruffle( .asInstanceOf[FramePointer] val slotIdx = fp.frameSlotIdx() argDefs(idx) = arg - val readArg = + val readArgNoCheck = ReadArgumentNode.build( idx, - arg.getDefaultValue.orElse(null), - checkNode + arg.getDefaultValue.orElse(null) ) + val readArg = TypeCheckValueNode.wrap(readArgNoCheck, checkNode) val assignmentArg = AssignmentNode.build(readArg, slotIdx) val argRead = ReadLocalVariableNode.build(new FramePointer(0, slotIdx)) @@ -806,7 +806,7 @@ class IrToTruffle( methodDef.methodName.name, fn.arguments, fn.body, - ReadArgumentCheckNode.build(context, "conversion", toType), + TypeCheckValueNode.single("conversion", toType), None, true ) @@ -847,27 +847,24 @@ class IrToTruffle( private def extractAscribedType( comment: String, t: Expression - ): ReadArgumentCheckNode = t match { + ): TypeCheckValueNode = t match { case u: `type`.Set.Union => val oneOf = u.operands.map(extractAscribedType(comment, _)) if (oneOf.contains(null)) { null } else { - ReadArgumentCheckNode.oneOf( - comment, - oneOf.asJava - ) + val arr: Array[TypeCheckValueNode] = oneOf.toArray + TypeCheckValueNode.oneOf(comment, arr: _*) } case i: `type`.Set.Intersection => - ReadArgumentCheckNode.allOf( + TypeCheckValueNode.allOf( comment, extractAscribedType(comment, i.left), extractAscribedType(comment, i.right) ) case p: Application.Prefix => extractAscribedType(comment, p.function) case _: Tpe.Function => - ReadArgumentCheckNode.build( - context, + TypeCheckValueNode.single( comment, context.getTopScope().getBuiltins().function() ) @@ -888,13 +885,13 @@ class IrToTruffle( if (context.getBuiltins().any() == typeOrAny) { null } else { - ReadArgumentCheckNode.build(context, comment, typeOrAny) + TypeCheckValueNode.single(comment, typeOrAny) } case Some( BindingsMap .Resolution(BindingsMap.ResolvedPolyglotSymbol(mod, symbol)) ) => - ReadArgumentCheckNode.meta( + TypeCheckValueNode.meta( comment, asScope( mod.unsafeAsModule().asInstanceOf[TruffleCompilerContext.Module] @@ -907,7 +904,7 @@ class IrToTruffle( private def checkAsTypes( arg: DefinitionArgument - ): ReadArgumentCheckNode = { + ): TypeCheckValueNode = { val comment = "`" + arg.name.name + "`" arg.ascribedType.map(extractAscribedType(comment, _)).getOrElse(null) } @@ -1312,7 +1309,7 @@ class IrToTruffle( extractAscribedType(asc.comment.orNull, asc.signature) if (checkNode != null) { val body = run(asc.typed, binding, subjectToInstrumentation) - ReadArgumentCheckNode.wrap(body, checkNode) + TypeCheckValueNode.wrap(body, checkNode) } else { processType(asc) } @@ -1341,7 +1338,7 @@ class IrToTruffle( extractAscribedType(tpe.comment.orNull, tpe.signature) if (checkNode != null) { runtimeExpression = - ReadArgumentCheckNode.wrap(runtimeExpression, checkNode) + TypeCheckValueNode.wrap(runtimeExpression, checkNode) } } } @@ -2192,7 +2189,7 @@ class IrToTruffle( val initialName: String, val arguments: List[DefinitionArgument], val body: Expression, - val typeCheck: ReadArgumentCheckNode, + val typeCheck: TypeCheckValueNode, val effectContext: Option[String], val subjectToInstrumentation: Boolean ) { @@ -2229,7 +2226,7 @@ class IrToTruffle( if (typeCheck == null) { (argExpressions.toArray, bodyExpr) } else { - val bodyWithCheck = ReadArgumentCheckNode.wrap(bodyExpr, typeCheck) + val bodyWithCheck = TypeCheckValueNode.wrap(bodyExpr, typeCheck) (argExpressions.toArray, bodyWithCheck) } } @@ -2255,12 +2252,12 @@ class IrToTruffle( ) .asInstanceOf[FramePointer] val slotIdx = fp.frameSlotIdx() - val readArg = + val readArgNoCheck = ReadArgumentNode.build( idx, - arg.getDefaultValue.orElse(null), - checkNode + arg.getDefaultValue.orElse(null) ) + val readArg = TypeCheckValueNode.wrap(readArgNoCheck, checkNode) val assignArg = AssignmentNode.build(readArg, slotIdx) argExpressions.append(assignArg) @@ -2589,7 +2586,7 @@ class IrToTruffle( def run( inputArg: DefinitionArgument, position: Int, - types: ReadArgumentCheckNode + types: TypeCheckValueNode ): ArgumentDefinition = inputArg match { case arg: DefinitionArgument.Specified => diff --git a/test/Base_Tests/src/Data/Complex.enso b/test/Base_Tests/src/Data/Complex.enso new file mode 100644 index 0000000000..9a0c98590b --- /dev/null +++ b/test/Base_Tests/src/Data/Complex.enso @@ -0,0 +1,33 @@ +import Standard.Base.Data.Numbers.Float +import Standard.Base.Data.Numbers.Number +import Standard.Base.Data.Ordering.Comparable +import Standard.Base.Data.Ordering.Ordering +import project.Data.Complex_Helpers +import project.Data.Complex_Helpers.Complex_Comparator + +## Sample definition of a complex number with conversions + from Number and implementation of a comparator. +type Complex + private Value re:Float im:Float + + new re=0:Float im=0:Float = + c = Complex.Value re im + if im != 0 then c:Complex else + c.as_complex_and_float + + + self (that:Complex) = Complex.new self.re+that.re self.im+that.im + + < self (that:Complex) = Complex_Comparator.compare self that == Ordering.Less + > self (that:Complex) = Complex_Comparator.compare self that == Ordering.Greater + >= self (that:Complex) = + ordering = Complex_Comparator.compare self that + ordering == Ordering.Greater || ordering == Ordering.Equal + <= self (that:Complex) = + ordering = Complex_Comparator.compare self that + ordering == Ordering.Less || ordering == Ordering.Equal + +Complex.from (that:Number) = Complex.new that + + +Comparable.from (that:Complex) = Comparable.new that Complex_Comparator +Comparable.from (that:Number) = Comparable.new that Complex_Comparator diff --git a/test/Base_Tests/src/Data/Complex_Helpers.enso b/test/Base_Tests/src/Data/Complex_Helpers.enso new file mode 100644 index 0000000000..ddbc408a5e --- /dev/null +++ b/test/Base_Tests/src/Data/Complex_Helpers.enso @@ -0,0 +1,25 @@ +private + +import Standard.Base.Nothing +import Standard.Base.Data.Numbers.Float +import Standard.Base.Data.Ordering.Ordering +import Standard.Base.Error.Error +import Standard.Base.Errors.Illegal_Argument.Illegal_Argument +import project.Data.Complex.Complex + +type Complex_Comparator + compare x:Complex y:Complex = if x.re==y.re && x.im==y.im then Ordering.Equal else + if x.im==0 && y.im==0 then Ordering.compare x.re y.re else + Nothing + hash x:Complex = if x.im == 0 then Ordering.hash x.re else + 7*x.re + 11*x.im + +## uses the explicit conversion defined in this private module +Complex.as_complex_and_float self = + self : Complex&Float + +## explicit "conversion" of `Complex` to `Float` in a private module + used in `as_complex_and_float` +Float.from (that:Complex) = + if that.im == 0 then that.re else + Error.throw <| Illegal_Argument.Error "Cannot convert Complex with imaginary part to Float" diff --git a/test/Base_Tests/src/Data/Numbers_Spec.enso b/test/Base_Tests/src/Data/Numbers_Spec.enso index 4a66dc988a..1d8f654d12 100644 --- a/test/Base_Tests/src/Data/Numbers_Spec.enso +++ b/test/Base_Tests/src/Data/Numbers_Spec.enso @@ -9,6 +9,7 @@ from Standard.Base.Data.Numbers import Number_Parse_Error from Standard.Test import all import project.Data.Round_Spec +import project.Data.Complex.Complex polyglot java import java.math.BigInteger polyglot java import java.math.BigDecimal @@ -18,34 +19,6 @@ Integer.is_even self = self % 2 == 0 Float.get_fun_factor self = "Wow, " + self.to_text + " is such a fun number!" -type Complex - Value re:Float im:Float - - new re=0 im=0 = Complex.Value re im - - + self (that:Complex) = Complex.new self.re+that.re self.im+that.im - - < self (that:Complex) = Complex_Comparator.compare self that == Ordering.Less - > self (that:Complex) = Complex_Comparator.compare self that == Ordering.Greater - >= self (that:Complex) = - ordering = Complex_Comparator.compare self that - ordering == Ordering.Greater || ordering == Ordering.Equal - <= self (that:Complex) = - ordering = Complex_Comparator.compare self that - ordering == Ordering.Less || ordering == Ordering.Equal - -Complex.from (that:Number) = Complex.new that - -type Complex_Comparator - compare x:Complex y:Complex = if x.re==y.re && x.im==y.im then Ordering.Equal else - if x.im==0 && y.im==0 then Ordering.compare x.re y.re else - Nothing - hash x:Complex = if x.im == 0 then Ordering.hash x.re else - 7*x.re + 11*x.im - -Comparable.from (that:Complex) = Comparable.new that Complex_Comparator -Comparable.from (that:Number) = Comparable.new that Complex_Comparator - float_id : Float -> Float float_id f = Number_Utils.floatId f diff --git a/test/Base_Tests/src/Main.enso b/test/Base_Tests/src/Main.enso index 9231667f14..3b6b71a1be 100644 --- a/test/Base_Tests/src/Main.enso +++ b/test/Base_Tests/src/Main.enso @@ -11,6 +11,7 @@ import project.Semantic.Import_Loop.Spec as Import_Loop_Spec import project.Semantic.Meta_Spec import project.Semantic.Instrumentor_Spec import project.Semantic.Meta_Location_Spec +import project.Semantic.Multi_Value_Spec import project.Semantic.Names_Spec import project.Semantic.Equals_Spec import project.Semantic.Runtime_Spec @@ -132,6 +133,7 @@ main filter=Nothing = Meta_Spec.add_specs suite_builder Instrumentor_Spec.add_specs suite_builder Meta_Location_Spec.add_specs suite_builder + Multi_Value_Spec.add_specs suite_builder Names_Spec.add_specs suite_builder Numbers_Spec.add_specs suite_builder Equals_Spec.add_specs suite_builder diff --git a/test/Base_Tests/src/Semantic/Multi_Value_Spec.enso b/test/Base_Tests/src/Semantic/Multi_Value_Spec.enso new file mode 100644 index 0000000000..5375822a8f --- /dev/null +++ b/test/Base_Tests/src/Semantic/Multi_Value_Spec.enso @@ -0,0 +1,20 @@ +from Standard.Base import all +from Standard.Test import all +import Standard.Base.Errors.Common.Type_Error + +import project.Data.Complex.Complex + +add_specs suite_builder = + suite_builder.group "Complex Multi Value" group_builder-> + group_builder.specify "Cannot convert to Float if it has imaginary part" <| + c = Complex.new 1 5 + Test.expect_panic Type_Error (c:Float) + group_builder.specify "Represents both Complex & Float with only real part" <| + c = Complex.new 1.5 0.0 + (c:Complex).re . should_equal 1.5 + (c:Float) . should_equal 1.5 + +main filter=Nothing = + suite = Test.build suite_builder-> + add_specs suite_builder + suite.run_with_filter filter