mirror of
https://github.com/enso-org/enso.git
synced 2024-11-22 22:10:15 +03:00
Improve TCO in the presence of warnings (#7116)
Partially revert https://github.com/enso-org/enso/pull/6849, which introduced a regression in TCO in the presence of warnings. Rather than modifying the tail call status, `TailCallException` now propagates the extracted warnings and appends them to the final result. Closes #7093 # Important Notes Compared to the previous attempt we don't pay the penalty of adding the warnings or even checking for them because it is being dealt in a separate specialization.
This commit is contained in:
parent
c4f19e7d66
commit
ae4666c4d3
@ -21,6 +21,7 @@ import org.enso.interpreter.runtime.callable.argument.CallArgumentInfo;
|
||||
import org.enso.interpreter.runtime.callable.atom.Atom;
|
||||
import org.enso.interpreter.runtime.callable.atom.AtomConstructor;
|
||||
import org.enso.interpreter.runtime.callable.function.Function;
|
||||
import org.enso.interpreter.runtime.control.TailCallException;
|
||||
import org.enso.interpreter.runtime.error.DataflowError;
|
||||
import org.enso.interpreter.runtime.error.PanicException;
|
||||
import org.enso.interpreter.runtime.error.PanicSentinel;
|
||||
@ -264,6 +265,15 @@ public abstract class InvokeCallableNode extends BaseNode {
|
||||
State state,
|
||||
Object[] arguments,
|
||||
@CachedLibrary(limit = "3") WarningsLibrary warnings) {
|
||||
|
||||
Warning[] extracted;
|
||||
Object callable;
|
||||
try {
|
||||
extracted = warnings.getWarnings(warning, null);
|
||||
callable = warnings.removeWarnings(warning);
|
||||
} catch (UnsupportedMessageException e) {
|
||||
throw CompilerDirectives.shouldNotReachHere(e);
|
||||
}
|
||||
try {
|
||||
if (childDispatch == null) {
|
||||
CompilerDirectives.transferToInterpreterAndInvalidate();
|
||||
@ -277,7 +287,7 @@ public abstract class InvokeCallableNode extends BaseNode {
|
||||
invokeFunctionNode.getSchema(),
|
||||
invokeFunctionNode.getDefaultsExecutionMode(),
|
||||
invokeFunctionNode.getArgumentsExecutionMode()));
|
||||
childDispatch.setTailStatus(TailStatus.NOT_TAIL);
|
||||
childDispatch.setTailStatus(getTailStatus());
|
||||
childDispatch.setId(invokeFunctionNode.getId());
|
||||
notifyInserted(childDispatch);
|
||||
}
|
||||
@ -287,12 +297,12 @@ public abstract class InvokeCallableNode extends BaseNode {
|
||||
}
|
||||
|
||||
var result = childDispatch.execute(
|
||||
warnings.removeWarnings(warning),
|
||||
callable,
|
||||
callerFrame,
|
||||
state,
|
||||
arguments);
|
||||
|
||||
Warning[] extracted = warnings.getWarnings(warning, null);
|
||||
|
||||
if (result instanceof DataflowError) {
|
||||
return result;
|
||||
} else if (result instanceof WithWarnings withWarnings) {
|
||||
@ -300,8 +310,8 @@ public abstract class InvokeCallableNode extends BaseNode {
|
||||
} else {
|
||||
return WithWarnings.wrap(EnsoContext.get(this), result, extracted);
|
||||
}
|
||||
} catch (UnsupportedMessageException e) {
|
||||
throw CompilerDirectives.shouldNotReachHere(e);
|
||||
} catch (TailCallException e) {
|
||||
throw new TailCallException(e, extracted);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -17,6 +17,7 @@ import org.enso.interpreter.runtime.EnsoContext;
|
||||
import org.enso.interpreter.runtime.callable.UnresolvedConversion;
|
||||
import org.enso.interpreter.runtime.callable.argument.CallArgumentInfo;
|
||||
import org.enso.interpreter.runtime.callable.function.Function;
|
||||
import org.enso.interpreter.runtime.control.TailCallException;
|
||||
import org.enso.interpreter.runtime.data.ArrayRope;
|
||||
import org.enso.interpreter.runtime.data.Type;
|
||||
import org.enso.interpreter.runtime.data.text.Text;
|
||||
@ -162,7 +163,7 @@ public abstract class InvokeConversionNode extends BaseNode {
|
||||
invokeFunctionNode.getDefaultsExecutionMode(),
|
||||
invokeFunctionNode.getArgumentsExecutionMode(),
|
||||
thatArgumentPosition));
|
||||
childDispatch.setTailStatus(TailStatus.NOT_TAIL);
|
||||
childDispatch.setTailStatus(getTailStatus());
|
||||
childDispatch.setId(invokeFunctionNode.getId());
|
||||
notifyInserted(childDispatch);
|
||||
}
|
||||
@ -170,11 +171,15 @@ public abstract class InvokeConversionNode extends BaseNode {
|
||||
lock.unlock();
|
||||
}
|
||||
}
|
||||
arguments[thatArgumentPosition] = that.getValue();
|
||||
Object value = that.getValue();
|
||||
arguments[thatArgumentPosition] = value;
|
||||
ArrayRope<Warning> warnings = that.getReassignedWarningsAsRope(this);
|
||||
Object result =
|
||||
childDispatch.execute(frame, state, conversion, self, that.getValue(), arguments);
|
||||
try {
|
||||
Object result = childDispatch.execute(frame, state, conversion, self, value, arguments);
|
||||
return WithWarnings.appendTo(EnsoContext.get(this), result, warnings);
|
||||
} catch (TailCallException e) {
|
||||
throw new TailCallException(e, warnings.toArray(Warning[]::new));
|
||||
}
|
||||
}
|
||||
|
||||
@Specialization(guards = "interop.isString(that)")
|
||||
|
@ -40,6 +40,7 @@ import org.enso.interpreter.runtime.callable.argument.ArgumentDefinition;
|
||||
import org.enso.interpreter.runtime.callable.argument.CallArgumentInfo;
|
||||
import org.enso.interpreter.runtime.callable.function.Function;
|
||||
import org.enso.interpreter.runtime.callable.function.FunctionSchema;
|
||||
import org.enso.interpreter.runtime.control.TailCallException;
|
||||
import org.enso.interpreter.runtime.data.ArrayRope;
|
||||
import org.enso.interpreter.runtime.data.EnsoDate;
|
||||
import org.enso.interpreter.runtime.data.EnsoDateTime;
|
||||
@ -407,7 +408,7 @@ public abstract class InvokeMethodNode extends BaseNode {
|
||||
invokeFunctionNode.getDefaultsExecutionMode(),
|
||||
invokeFunctionNode.getArgumentsExecutionMode(),
|
||||
thisArgumentPosition));
|
||||
childDispatch.setTailStatus(TailStatus.NOT_TAIL);
|
||||
childDispatch.setTailStatus(getTailStatus());
|
||||
childDispatch.setId(invokeFunctionNode.getId());
|
||||
notifyInserted(childDispatch);
|
||||
}
|
||||
@ -418,8 +419,12 @@ public abstract class InvokeMethodNode extends BaseNode {
|
||||
|
||||
arguments[thisArgumentPosition] = selfWithoutWarnings;
|
||||
|
||||
try {
|
||||
Object result = childDispatch.execute(frame, state, symbol, selfWithoutWarnings, arguments);
|
||||
return WithWarnings.appendTo(EnsoContext.get(this), result, arrOfWarnings);
|
||||
} catch (TailCallException e) {
|
||||
throw new TailCallException(e, arrOfWarnings);
|
||||
}
|
||||
}
|
||||
|
||||
@ExplodeLoop
|
||||
|
@ -5,6 +5,7 @@ import com.oracle.truffle.api.nodes.Node;
|
||||
import com.oracle.truffle.api.nodes.NodeInfo;
|
||||
import org.enso.interpreter.runtime.callable.CallerInfo;
|
||||
import org.enso.interpreter.runtime.callable.function.Function;
|
||||
import org.enso.interpreter.runtime.error.Warning;
|
||||
import org.enso.interpreter.runtime.state.State;
|
||||
|
||||
/**
|
||||
@ -33,6 +34,7 @@ public abstract class CallOptimiserNode extends Node {
|
||||
* @param callerInfo the caller info to pass to the function
|
||||
* @param state the state to pass to the function
|
||||
* @param arguments the arguments to {@code callable}
|
||||
* @param warnings warnings associated with the callable, null if empty
|
||||
* @return the result of executing {@code callable} using {@code arguments}
|
||||
*/
|
||||
public abstract Object executeDispatch(
|
||||
@ -40,5 +42,6 @@ public abstract class CallOptimiserNode extends Node {
|
||||
Function callable,
|
||||
CallerInfo callerInfo,
|
||||
State state,
|
||||
Object[] arguments);
|
||||
Object[] arguments,
|
||||
Warning[] warnings);
|
||||
}
|
||||
|
@ -133,7 +133,7 @@ public class CurryNode extends BaseNode {
|
||||
return value;
|
||||
}
|
||||
} else {
|
||||
var evaluatedVal = loopingCall.executeDispatch(frame, function, callerInfo, state, arguments);
|
||||
var evaluatedVal = loopingCall.executeDispatch(frame, function, callerInfo, state, arguments, null);
|
||||
|
||||
return this.oversaturatedCallableNode.execute(
|
||||
evaluatedVal, frame, state, oversaturatedArguments);
|
||||
@ -154,7 +154,7 @@ public class CurryNode extends BaseNode {
|
||||
return switch (getTailStatus()) {
|
||||
case TAIL_DIRECT -> directCall.executeCall(frame, function, callerInfo, state, arguments);
|
||||
case TAIL_LOOP -> throw new TailCallException(function, callerInfo, arguments);
|
||||
default -> loopingCall.executeDispatch(frame, function, callerInfo, state, arguments);
|
||||
default -> loopingCall.executeDispatch(frame, function, callerInfo, state, arguments, null);
|
||||
};
|
||||
}
|
||||
}
|
||||
|
@ -92,7 +92,7 @@ public abstract class IndirectCurryNode extends Node {
|
||||
return value;
|
||||
}
|
||||
} else {
|
||||
var evaluatedVal = loopingCall.executeDispatch(frame, function, callerInfo, state, arguments);
|
||||
var evaluatedVal = loopingCall.executeDispatch(frame, function, callerInfo, state, arguments, null);
|
||||
|
||||
return oversaturatedCallableNode.execute(
|
||||
evaluatedVal,
|
||||
@ -129,7 +129,7 @@ public abstract class IndirectCurryNode extends Node {
|
||||
case TAIL_LOOP:
|
||||
throw new TailCallException(function, callerInfo, arguments);
|
||||
default:
|
||||
return loopingCall.executeDispatch(frame, function, callerInfo, state, arguments);
|
||||
return loopingCall.executeDispatch(frame, function, callerInfo, state, arguments, null);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -15,9 +15,12 @@ import com.oracle.truffle.api.nodes.NodeInfo;
|
||||
import com.oracle.truffle.api.nodes.RepeatingNode;
|
||||
import org.enso.interpreter.node.callable.ExecuteCallNode;
|
||||
import org.enso.interpreter.node.callable.ExecuteCallNodeGen;
|
||||
import org.enso.interpreter.runtime.EnsoContext;
|
||||
import org.enso.interpreter.runtime.callable.CallerInfo;
|
||||
import org.enso.interpreter.runtime.callable.function.Function;
|
||||
import org.enso.interpreter.runtime.control.TailCallException;
|
||||
import org.enso.interpreter.runtime.error.Warning;
|
||||
import org.enso.interpreter.runtime.error.WithWarnings;
|
||||
import org.enso.interpreter.runtime.state.State;
|
||||
|
||||
/**
|
||||
@ -54,23 +57,44 @@ public abstract class LoopingCallOptimiserNode extends CallOptimiserNode {
|
||||
* @param loopNode a cached instance of the loop node used by this node
|
||||
* @return the result of executing {@code function} using {@code arguments}
|
||||
*/
|
||||
@Specialization
|
||||
public Object dispatch(
|
||||
@Specialization(guards = "warnings == null")
|
||||
public Object cachedDispatch(
|
||||
Function function,
|
||||
CallerInfo callerInfo,
|
||||
State state,
|
||||
Object[] arguments,
|
||||
Warning[] warnings,
|
||||
@Cached(value = "createLoopNode()") LoopNode loopNode) {
|
||||
return dispatch(function, callerInfo, state, arguments, loopNode);
|
||||
}
|
||||
|
||||
@Specialization(guards = "warnings != null")
|
||||
public Object cachedDispatchWarnings(
|
||||
Function function,
|
||||
CallerInfo callerInfo,
|
||||
State state,
|
||||
Object[] arguments,
|
||||
Warning[] warnings,
|
||||
@Cached(value = "createLoopNode()") LoopNode loopNode) {
|
||||
Object result = dispatch(function, callerInfo, state, arguments, loopNode);
|
||||
return WithWarnings.appendTo(EnsoContext.get(this), result, warnings);
|
||||
}
|
||||
|
||||
private Object dispatch(
|
||||
Function function,
|
||||
CallerInfo callerInfo,
|
||||
State state,
|
||||
Object[] arguments,
|
||||
LoopNode loopNode) {
|
||||
RepeatedCallNode repeatedCallNode = (RepeatedCallNode) loopNode.getRepeatingNode();
|
||||
VirtualFrame frame = repeatedCallNode.createFrame();
|
||||
repeatedCallNode.setNextCall(frame, function, callerInfo, arguments);
|
||||
repeatedCallNode.setState(frame, state);
|
||||
loopNode.execute(frame);
|
||||
|
||||
return repeatedCallNode.getResult(frame);
|
||||
}
|
||||
|
||||
@Specialization(replaces = "dispatch")
|
||||
@Specialization(replaces = "cachedDispatch", guards = "warnings == null")
|
||||
@CompilerDirectives.TruffleBoundary
|
||||
public Object uncachedDispatch(
|
||||
MaterializedFrame frame,
|
||||
@ -78,7 +102,33 @@ public abstract class LoopingCallOptimiserNode extends CallOptimiserNode {
|
||||
CallerInfo callerInfo,
|
||||
State state,
|
||||
Object[] arguments,
|
||||
Warning[] warnings,
|
||||
@Cached ExecuteCallNode executeCallNode) {
|
||||
return loopUntilCompletion(frame, function, callerInfo, state, arguments, executeCallNode);
|
||||
}
|
||||
|
||||
@Specialization(replaces = "cachedDispatchWarnings", guards = "warnings != null")
|
||||
@CompilerDirectives.TruffleBoundary
|
||||
public Object uncachedDispatchWarnings(
|
||||
MaterializedFrame frame,
|
||||
Function function,
|
||||
CallerInfo callerInfo,
|
||||
State state,
|
||||
Object[] arguments,
|
||||
Warning[] warnings,
|
||||
@Cached ExecuteCallNode executeCallNode) {
|
||||
Object result =
|
||||
loopUntilCompletion(frame, function, callerInfo, state, arguments, executeCallNode);
|
||||
return WithWarnings.appendTo(EnsoContext.get(this), result, warnings);
|
||||
}
|
||||
|
||||
private Object loopUntilCompletion(
|
||||
MaterializedFrame frame,
|
||||
Function function,
|
||||
CallerInfo callerInfo,
|
||||
State state,
|
||||
Object[] arguments,
|
||||
ExecuteCallNode executeCallNode) {
|
||||
while (true) {
|
||||
try {
|
||||
return executeCallNode.executeCall(frame, function, callerInfo, state, arguments);
|
||||
|
@ -9,6 +9,7 @@ import org.enso.interpreter.node.callable.ExecuteCallNodeGen;
|
||||
import org.enso.interpreter.runtime.callable.CallerInfo;
|
||||
import org.enso.interpreter.runtime.callable.function.Function;
|
||||
import org.enso.interpreter.runtime.control.TailCallException;
|
||||
import org.enso.interpreter.runtime.error.Warning;
|
||||
import org.enso.interpreter.runtime.state.State;
|
||||
|
||||
/**
|
||||
@ -40,6 +41,7 @@ public class SimpleCallOptimiserNode extends CallOptimiserNode {
|
||||
* @param callerInfo the caller info to pass to the function
|
||||
* @param state the state to pass to the function
|
||||
* @param arguments the arguments to {@code function}
|
||||
* @param warnings warnings associated with the callable, null if empty
|
||||
* @return the result of executing {@code function} using {@code arguments}
|
||||
*/
|
||||
@Override
|
||||
@ -48,7 +50,8 @@ public class SimpleCallOptimiserNode extends CallOptimiserNode {
|
||||
Function function,
|
||||
CallerInfo callerInfo,
|
||||
State state,
|
||||
Object[] arguments) {
|
||||
Object[] arguments,
|
||||
Warning[] warnings) {
|
||||
try {
|
||||
return executeCallNode.executeCall(frame, function, callerInfo, state, arguments);
|
||||
} catch (TailCallException e) {
|
||||
@ -65,7 +68,7 @@ public class SimpleCallOptimiserNode extends CallOptimiserNode {
|
||||
}
|
||||
}
|
||||
return next.executeDispatch(
|
||||
frame, e.getFunction(), e.getCallerInfo(), state, e.getArguments());
|
||||
frame, e.getFunction(), e.getCallerInfo(), state, e.getArguments(), e.getWarnings());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -67,7 +67,7 @@ public abstract class ThunkExecutorNode extends Node {
|
||||
return callNode.call(Function.ArgumentsHelper.buildArguments(function, state));
|
||||
} catch (TailCallException e) {
|
||||
return loopingCallOptimiserNode.executeDispatch(
|
||||
frame, e.getFunction(), e.getCallerInfo(), state, e.getArguments());
|
||||
frame, e.getFunction(), e.getCallerInfo(), state, e.getArguments(), e.getWarnings());
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -89,7 +89,7 @@ public abstract class ThunkExecutorNode extends Node {
|
||||
function.getCallTarget(), Function.ArgumentsHelper.buildArguments(function, state));
|
||||
} catch (TailCallException e) {
|
||||
return loopingCallOptimiserNode.executeDispatch(
|
||||
frame, e.getFunction(), e.getCallerInfo(), state, e.getArguments());
|
||||
frame, e.getFunction(), e.getCallerInfo(), state, e.getArguments(), e.getWarnings());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -811,8 +811,8 @@ public abstract class SortVectorNode extends Node {
|
||||
Object yConverted;
|
||||
if (hasCustomOnFunc) {
|
||||
// onFunc cannot have `self` argument, we assume it has just one argument.
|
||||
xConverted = callNode.executeDispatch(null, onFunc.get(x), null, state, new Object[]{x});
|
||||
yConverted = callNode.executeDispatch(null, onFunc.get(y), null, state, new Object[]{y});
|
||||
xConverted = callNode.executeDispatch(null, onFunc.get(x), null, state, new Object[]{x}, null);
|
||||
yConverted = callNode.executeDispatch(null, onFunc.get(y), null, state, new Object[]{y}, null);
|
||||
} else {
|
||||
xConverted = x;
|
||||
yConverted = y;
|
||||
@ -823,7 +823,7 @@ public abstract class SortVectorNode extends Node {
|
||||
} else {
|
||||
args = new Object[] {xConverted, yConverted};
|
||||
}
|
||||
Object res = callNode.executeDispatch(null, compareFunc.get(xConverted), null, state, args);
|
||||
Object res = callNode.executeDispatch(null, compareFunc.get(xConverted), null, state, args, null);
|
||||
if (res == less) {
|
||||
return ascending ? -1 : 1;
|
||||
} else if (res == equal) {
|
||||
|
@ -636,7 +636,8 @@ public final class Module implements TruffleObject {
|
||||
eval.getFunction(),
|
||||
callerInfo,
|
||||
context.emptyState(),
|
||||
new Object[] {builtins.debug(), Text.create(expr)});
|
||||
new Object[] {builtins.debug(), Text.create(expr)},
|
||||
null);
|
||||
}
|
||||
|
||||
private static Object generateDocs(Module module, EnsoContext context) {
|
||||
|
@ -3,6 +3,7 @@ package org.enso.interpreter.runtime.control;
|
||||
import com.oracle.truffle.api.nodes.ControlFlowException;
|
||||
import org.enso.interpreter.runtime.callable.CallerInfo;
|
||||
import org.enso.interpreter.runtime.callable.function.Function;
|
||||
import org.enso.interpreter.runtime.error.Warning;
|
||||
|
||||
/**
|
||||
* Used to model the switch of control-flow from standard stack-based execution to looping.
|
||||
@ -13,18 +14,38 @@ public class TailCallException extends ControlFlowException {
|
||||
private final Function function;
|
||||
private final CallerInfo callerInfo;
|
||||
private final Object[] arguments;
|
||||
private final Warning[] warnings;
|
||||
|
||||
/**
|
||||
* Creates a new exception containing the necessary data to continue computation.
|
||||
*
|
||||
* @param function the function to execute in a loop
|
||||
* @param state the state to pass to the function
|
||||
* @param callerInfo the caller execution context
|
||||
* @param arguments the arguments to {@code function}
|
||||
*/
|
||||
public TailCallException(Function function, CallerInfo callerInfo, Object[] arguments) {
|
||||
this.function = function;
|
||||
this.callerInfo = callerInfo;
|
||||
this.arguments = arguments;
|
||||
this.warnings = null;
|
||||
}
|
||||
|
||||
private TailCallException(
|
||||
Function function, CallerInfo callerInfo, Object[] arguments, Warning[] warnings) {
|
||||
this.function = function;
|
||||
this.callerInfo = callerInfo;
|
||||
this.arguments = arguments;
|
||||
this.warnings = warnings;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new exception containing the necessary data to continue computation.
|
||||
*
|
||||
* @param origin the original tail call exception
|
||||
* @param warnings warnings to be associated with the tail call exception
|
||||
*/
|
||||
public TailCallException(TailCallException origin, Warning[] warnings) {
|
||||
this(origin.getFunction(), origin.getCallerInfo(), origin.getArguments(), warnings);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -53,4 +74,13 @@ public class TailCallException extends ControlFlowException {
|
||||
public CallerInfo getCallerInfo() {
|
||||
return callerInfo;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the warnings that should be appended to the result of calling the function.
|
||||
*
|
||||
* @return the warnings to be appended to the result of the call, or null if empty
|
||||
*/
|
||||
public Warning[] getWarnings() {
|
||||
return warnings;
|
||||
}
|
||||
}
|
||||
|
@ -437,4 +437,20 @@ spec = Test.group "Dataflow Warnings" <|
|
||||
result_non_tail . should_equal 6
|
||||
Warning.get_all result_non_tail . map .value . should_equal ["Foo"]
|
||||
|
||||
Test.specify "should not break TCO when warnings are attached to arguments" <|
|
||||
vec = Vector.new 10000 (i-> i+1)
|
||||
elem1 = Warning.attach "WARNING1" 998
|
||||
vec.contains 998 . should_equal True
|
||||
res1 = vec.contains elem1
|
||||
res1 . should_be_true
|
||||
Warning.get_all res1 . map .value . should_equal ["WARNING1"]
|
||||
|
||||
elem2 = Warning.attach "WARNING2" 9988
|
||||
vec.contains 9988 . should_be_true
|
||||
vec.contains elem2 . should_be_true
|
||||
|
||||
res2 = vec.contains elem2
|
||||
res2 . should_equal True
|
||||
Warning.get_all res2 . map .value . should_equal ["WARNING2"]
|
||||
|
||||
main = Test_Suite.run_main spec
|
||||
|
Loading…
Reference in New Issue
Block a user