Delay check of suspended arguments until they are about to be computed (#7727)

This commit is contained in:
Jaroslav Tulach 2023-09-06 09:34:12 +02:00 committed by GitHub
parent 1f8675a031
commit ab1c1a4c12
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 235 additions and 26 deletions

View File

@ -1,21 +1,27 @@
package org.enso.interpreter.node.callable.argument; package org.enso.interpreter.node.callable.argument;
import com.oracle.truffle.api.dsl.Cached.Shared;
import java.util.Arrays; import java.util.Arrays;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import org.enso.interpreter.EnsoLanguage;
import org.enso.interpreter.node.BaseNode.TailStatus;
import org.enso.interpreter.node.EnsoRootNode; import org.enso.interpreter.node.EnsoRootNode;
import org.enso.interpreter.node.callable.ApplicationNode; import org.enso.interpreter.node.callable.ApplicationNode;
import org.enso.interpreter.node.callable.InvokeCallableNode.DefaultsExecutionMode; import org.enso.interpreter.node.callable.InvokeCallableNode.DefaultsExecutionMode;
import org.enso.interpreter.node.callable.thunk.ThunkExecutorNode;
import org.enso.interpreter.node.expression.builtin.meta.AtomWithAHoleNode; import org.enso.interpreter.node.expression.builtin.meta.AtomWithAHoleNode;
import org.enso.interpreter.node.expression.builtin.meta.IsValueOfTypeNode; import org.enso.interpreter.node.expression.builtin.meta.IsValueOfTypeNode;
import org.enso.interpreter.node.expression.builtin.meta.TypeOfNode; import org.enso.interpreter.node.expression.builtin.meta.TypeOfNode;
import org.enso.interpreter.node.expression.literal.LiteralNode; import org.enso.interpreter.node.expression.literal.LiteralNode;
import org.enso.interpreter.runtime.EnsoContext; import org.enso.interpreter.runtime.EnsoContext;
import org.enso.interpreter.runtime.callable.Annotation;
import org.enso.interpreter.runtime.callable.UnresolvedConversion; import org.enso.interpreter.runtime.callable.UnresolvedConversion;
import org.enso.interpreter.runtime.callable.argument.ArgumentDefinition;
import org.enso.interpreter.runtime.callable.argument.ArgumentDefinition.ExecutionMode;
import org.enso.interpreter.runtime.callable.argument.CallArgument; import org.enso.interpreter.runtime.callable.argument.CallArgument;
import org.enso.interpreter.runtime.callable.argument.CallArgumentInfo;
import org.enso.interpreter.runtime.callable.function.Function; import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.runtime.callable.function.FunctionSchema;
import org.enso.interpreter.runtime.data.Type; import org.enso.interpreter.runtime.data.Type;
import org.enso.interpreter.runtime.error.DataflowError; import org.enso.interpreter.runtime.error.DataflowError;
import org.enso.interpreter.runtime.error.PanicException; import org.enso.interpreter.runtime.error.PanicException;
@ -23,13 +29,16 @@ import org.graalvm.collections.Pair;
import com.oracle.truffle.api.CompilerAsserts; import com.oracle.truffle.api.CompilerAsserts;
import com.oracle.truffle.api.CompilerDirectives; import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.TruffleLanguage;
import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Cached.Shared;
import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.frame.MaterializedFrame; import com.oracle.truffle.api.frame.MaterializedFrame;
import com.oracle.truffle.api.frame.VirtualFrame; import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.ExplodeLoop; import com.oracle.truffle.api.nodes.ExplodeLoop;
import com.oracle.truffle.api.nodes.InvalidAssumptionException; import com.oracle.truffle.api.nodes.InvalidAssumptionException;
import com.oracle.truffle.api.nodes.Node; import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.api.nodes.RootNode;
public abstract class ReadArgumentCheckNode extends Node { public abstract class ReadArgumentCheckNode extends Node {
private final String name; private final String name;
@ -38,6 +47,8 @@ public abstract class ReadArgumentCheckNode extends Node {
private final Type[] expectedTypes; private final Type[] expectedTypes;
@CompilerDirectives.CompilationFinal @CompilerDirectives.CompilationFinal
private String expectedTypeMessage; private String expectedTypeMessage;
@CompilerDirectives.CompilationFinal
private LazyCheckRootNode lazyCheck;
ReadArgumentCheckNode(String name, Type[] expectedTypes) { ReadArgumentCheckNode(String name, Type[] expectedTypes) {
this.name = name; this.name = name;
@ -55,10 +66,18 @@ public abstract class ReadArgumentCheckNode extends Node {
public abstract Object executeCheckOrConversion(VirtualFrame frame, Object value); public abstract Object executeCheckOrConversion(VirtualFrame frame, Object value);
public static boolean isWrappedThunk(Function fn) {
if (fn.getSchema() == LazyCheckRootNode.SCHEMA) {
return fn.getPreAppliedArguments()[0] instanceof Function wrappedFn && wrappedFn.isThunk();
}
return false;
}
@Specialization(rewriteOn = InvalidAssumptionException.class) @Specialization(rewriteOn = InvalidAssumptionException.class)
Object doCheckNoConversionNeeded(VirtualFrame frame, Object v) throws InvalidAssumptionException { Object doCheckNoConversionNeeded(VirtualFrame frame, Object v) throws InvalidAssumptionException {
if (findAmongTypes(v)) { var ret = findAmongTypes(v);
return v; if (ret != null) {
return ret;
} else { } else {
throw new InvalidAssumptionException(); throw new InvalidAssumptionException();
} }
@ -83,26 +102,34 @@ public abstract class ReadArgumentCheckNode extends Node {
@Shared("typeOfNode") @Cached TypeOfNode typeOfNode @Shared("typeOfNode") @Cached TypeOfNode typeOfNode
) { ) {
var type = findType(typeOfNode, v); var type = findType(typeOfNode, v);
return doWithConversionUncachedBoundary(frame.materialize(), v, type); return doWithConversionUncachedBoundary(frame == null ? null : frame.materialize(), v, type);
} }
private static boolean isAllFitValue(Object v) { private static boolean isAllFitValue(Object v) {
return v instanceof DataflowError return v instanceof DataflowError || AtomWithAHoleNode.isHole(v);
|| (v instanceof Function fn && fn.isThunk())
|| AtomWithAHoleNode.isHole(v);
} }
@ExplodeLoop @ExplodeLoop
private boolean findAmongTypes(Object v) { private Object findAmongTypes(Object v) {
if (isAllFitValue(v)) { if (isAllFitValue(v)) {
return true; return v;
}
if (v instanceof Function fn && fn.isThunk()) {
if (lazyCheck == null) {
CompilerDirectives.transferToInterpreter();
var enso = EnsoLanguage.get(this);
var node = (ReadArgumentCheckNode) copy();
lazyCheck = new LazyCheckRootNode(enso, node);
}
var lazyCheckFn = lazyCheck.wrapThunk(fn);
return lazyCheckFn;
} }
for (Type t : expectedTypes) { for (Type t : expectedTypes) {
if (checkType.execute(t, v)) { if (checkType.execute(t, v)) {
return true; return v;
} }
} }
return false; return null;
} }
@ExplodeLoop @ExplodeLoop
@ -149,8 +176,9 @@ public abstract class ReadArgumentCheckNode extends Node {
VirtualFrame frame, Object v, ApplicationNode convertNode VirtualFrame frame, Object v, ApplicationNode convertNode
) throws PanicException { ) throws PanicException {
if (convertNode == null) { if (convertNode == null) {
if (findAmongTypes(v)) { var ret = findAmongTypes(v);
return v; if (ret != null) {
return ret;
} }
throw panicAtTheEnd(v); throw panicAtTheEnd(v);
} else { } else {
@ -181,4 +209,40 @@ public abstract class ReadArgumentCheckNode extends Node {
Arrays.stream(expectedTypes).map(Type::toString).collect(Collectors.joining(" | ")); Arrays.stream(expectedTypes).map(Type::toString).collect(Collectors.joining(" | "));
return expectedTypeMessage; return expectedTypeMessage;
} }
private static final class LazyCheckRootNode extends RootNode {
@Child
private ThunkExecutorNode evalThunk;
@Child
private ReadArgumentCheckNode check;
static final FunctionSchema SCHEMA = new FunctionSchema(
FunctionSchema.CallerFrameAccess.NONE,
new ArgumentDefinition[] { new ArgumentDefinition(0, "delegate", null, null, ExecutionMode.EXECUTE) },
new boolean[] { true },
new CallArgumentInfo[0],
new Annotation[0]
);
LazyCheckRootNode(TruffleLanguage<?> language, ReadArgumentCheckNode check) {
super(language);
this.check = check;
this.evalThunk = ThunkExecutorNode.build();
}
Function wrapThunk(Function thunk) {
return new Function(getCallTarget(), thunk.getScope(), SCHEMA, new Object[] { thunk }, null);
}
@Override
public Object execute(VirtualFrame frame) {
var state = Function.ArgumentsHelper.getState(frame.getArguments());
var args = Function.ArgumentsHelper.getPositionalArguments(frame.getArguments());
assert args.length == 1;
assert args[0] instanceof Function fn && fn.isThunk();
var raw = evalThunk.executeThunk(frame, args[0], state, TailStatus.NOT_TAIL);
var result = check.executeCheckOrConversion(frame, raw);
return result;
}
}
} }

View File

@ -4,12 +4,16 @@ import com.oracle.truffle.api.exception.AbstractTruffleException;
import com.oracle.truffle.api.nodes.Node; import com.oracle.truffle.api.nodes.Node;
import org.enso.interpreter.node.callable.InvokeCallableNode; import org.enso.interpreter.node.callable.InvokeCallableNode;
import org.enso.interpreter.node.callable.argument.ReadArgumentCheckNode;
import org.enso.interpreter.node.callable.dispatch.InvokeFunctionNode; import org.enso.interpreter.node.callable.dispatch.InvokeFunctionNode;
import org.enso.interpreter.runtime.EnsoContext; import org.enso.interpreter.runtime.EnsoContext;
import org.enso.interpreter.runtime.callable.argument.CallArgumentInfo; import org.enso.interpreter.runtime.callable.argument.CallArgumentInfo;
import org.enso.interpreter.runtime.callable.atom.Atom; import org.enso.interpreter.runtime.callable.atom.Atom;
import org.enso.interpreter.runtime.callable.function.Function; import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.runtime.data.EnsoObject; import org.enso.interpreter.runtime.data.EnsoObject;
import org.enso.interpreter.runtime.data.text.Text;
import org.enso.interpreter.runtime.error.DataflowError;
import org.enso.interpreter.runtime.error.PanicException;
import org.enso.interpreter.runtime.state.State; import org.enso.interpreter.runtime.state.State;
/** /**
@ -35,32 +39,28 @@ final class SuspendedFieldGetterNode extends UnboxingAtom.FieldGetterNode {
return new SuspendedFieldGetterNode(get, set); return new SuspendedFieldGetterNode(get, set);
} }
private static boolean shallBeExtracted(Function fn) {
return fn.isThunk() || ReadArgumentCheckNode.isWrappedThunk(fn);
}
@Override @Override
public Object execute(Atom atom) { public Object execute(Atom atom) {
java.lang.Object value = get.execute(atom); java.lang.Object value = get.execute(atom);
if (value instanceof Function fn && fn.isThunk()) { if (value instanceof Function fn && shallBeExtracted(fn)) {
try { try {
org.enso.interpreter.runtime.EnsoContext ctx = EnsoContext.get(this); org.enso.interpreter.runtime.EnsoContext ctx = EnsoContext.get(this);
java.lang.Object newValue = invoke.execute(fn, null, State.create(ctx), new Object[0]); java.lang.Object newValue = invoke.execute(fn, null, State.create(ctx), new Object[0]);
set.execute(atom, newValue); set.execute(atom, newValue);
return newValue; return newValue;
} catch (AbstractTruffleException ex) { } catch (AbstractTruffleException ex) {
var rethrow = new SuspendedException(ex); var rethrow = DataflowError.withTrace(ex, ex);
set.execute(atom, rethrow); set.execute(atom, rethrow);
throw ex; throw ex;
} }
} else if (value instanceof SuspendedException suspended) { } else if (value instanceof DataflowError suspended && suspended.getPayload() instanceof AbstractTruffleException ex) {
throw suspended.ex; throw ex;
} else { } else {
return value; return value;
} }
} }
private static final class SuspendedException implements EnsoObject {
final AbstractTruffleException ex;
SuspendedException(AbstractTruffleException ex) {
this.ex = ex;
}
}
} }

View File

@ -2,6 +2,7 @@ package org.enso.interpreter.test;
import java.net.URI; import java.net.URI;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.util.ArrayList;
import org.graalvm.polyglot.Context; import org.graalvm.polyglot.Context;
import org.graalvm.polyglot.PolyglotException; import org.graalvm.polyglot.PolyglotException;
@ -9,6 +10,7 @@ import org.graalvm.polyglot.Source;
import org.graalvm.polyglot.Value; import org.graalvm.polyglot.Value;
import org.junit.AfterClass; import org.junit.AfterClass;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import org.junit.BeforeClass; import org.junit.BeforeClass;
@ -96,6 +98,149 @@ public class SignatureTest extends TestBase {
assertTrue("Yields Error value", yieldsError.isException()); assertTrue("Yields Error value", yieldsError.isException());
} }
@Test
public void lazyIntegerInConstructor() throws Exception {
final URI uri = new URI("memory://int_simple_complex.enso");
final Source src = Source.newBuilder("enso", """
from Standard.Base import all
type Int
Simple v
Complex (~unwrap : Int)
value self = case self of
Int.Simple v -> v
Int.Complex unwrap -> unwrap.value
+ self (that:Int) = Int.Simple self.value+that.value
simple v = Int.Simple v
complex x y = Int.Complex (x+y)
""", uri.getHost())
.uri(uri)
.buildLiteral();
var module = ctx.eval(src);
var simple = module.invokeMember("eval_expression", "simple");
var complex = module.invokeMember("eval_expression", "complex");
var six = simple.execute(6);
var seven = simple.execute(7);
var some13 = complex.execute(six, seven);
var thirteen = some13.invokeMember("value");
assertNotNull("member found", thirteen);
assertEquals(13, thirteen.asInt());
var someHello = complex.execute("Hello", "World");
try {
var error = someHello.invokeMember("value");
fail("not expecting any value: " + error);
} catch (PolyglotException e) {
assertTypeError("`unwrap`", "Int", "Text", e.getMessage());
}
try {
var secondError = someHello.invokeMember("value");
fail("not expecting any value again: " + secondError);
} catch (PolyglotException e) {
assertTypeError("`unwrap`", "Int", "Text", e.getMessage());
}
}
@Test
public void runtimeCheckOfLazyAscribedFunctionSignature() throws Exception {
final URI uri = new URI("memory://neg_lazy.enso");
final Source src = Source.newBuilder("enso", """
from Standard.Base import Integer, IO
build (~zero : Integer) =
neg (~a : Integer) = zero - a
neg
make arr = build <|
arr.at 0
""", uri.getHost())
.uri(uri)
.buildLiteral();
var module = ctx.eval(src);
var zeroValue = new Object[] { 0 };
var neg = module.invokeMember("eval_expression", "make").execute((Object)zeroValue);
zeroValue[0] = "Wrong";
try {
var error = neg.execute(-5);
fail("Expecting an error: " + error);
} catch (PolyglotException ex) {
assertTypeError("`zero`", "Integer", "Text", ex.getMessage());
}
zeroValue[0] = 0;
var five = neg.execute(-5);
assertEquals("Five", 5, five.asInt());
try {
var res = neg.execute("Hi");
fail("Expecting an exception, not: " + res);
} catch (PolyglotException e) {
assertTypeError("`a`", "Integer", "Text", e.getMessage());
}
zeroValue[0] = 5;
var fifteen = neg.execute(-10);
assertEquals("Five + Ten as the zeroValue[0] is always read again", 15, fifteen.asInt());
zeroValue[0] = 0;
var ten = neg.execute(-10);
assertEquals("Just ten as the zeroValue[0] is always read again", 10, ten.asInt());
}
@Test
public void runtimeCheckOfLazyAscribedConstructorSignature() throws Exception {
final URI uri = new URI("memory://neg_lazy_const.enso");
final Source src = Source.newBuilder("enso", """
from Standard.Base import Integer, IO, Polyglot
type Lazy
Value (~zero : Integer)
neg self (~a : Integer) = self.zero - a
make arr = Lazy.Value <|
Polyglot.invoke arr "add" [ arr.length ]
arr.at 0
""", uri.getHost())
.uri(uri)
.buildLiteral();
var module = ctx.eval(src);
var zeroValue = new ArrayList<Integer>();
zeroValue.add(0);
var lazy = module.invokeMember("eval_expression", "make").execute((Object)zeroValue);
assertEquals("No read from zeroValue, still size 1", 1, zeroValue.size());
var five = lazy.invokeMember("neg", -5);
assertEquals("Five", 5, five.asInt());
assertEquals("One read from zeroValue, size 2", 2, zeroValue.size());
try {
var res = lazy.invokeMember("neg", "Hi");
fail("Expecting an exception, not: " + res);
} catch (PolyglotException e) {
assertTypeError("`a`", "Integer", "Text", e.getMessage());
}
zeroValue.set(0, 5);
var fifteen = lazy.invokeMember("neg", -10);
assertEquals("Five + Ten as the zeroValue[0] is never read again", 10, fifteen.asInt());
assertEquals("One read from zeroValue, size 2", 2, zeroValue.size());
zeroValue.set(0, 0);
var ten = lazy.invokeMember("neg", -9);
assertEquals("Just nine as the zeroValue[0] is always read again", 9, ten.asInt());
assertEquals("One read from zeroValue, size 2", 2, zeroValue.size());
}
@Test @Test
public void runtimeCheckOfAscribedInstanceMethodSignature() throws Exception { public void runtimeCheckOfAscribedInstanceMethodSignature() throws Exception {
final URI uri = new URI("memory://twice_instance.enso"); final URI uri = new URI("memory://twice_instance.enso");