Improve TCO in the presence of warnings (#7116)

Partially revert https://github.com/enso-org/enso/pull/6849, which introduced a regression in TCO in the presence of warnings. Rather than modifying the tail call status, `TailCallException` now propagates the extracted warnings and appends them to the final result.

Closes #7093

# Important Notes
Compared to the previous attempt we don't pay the penalty of adding the warnings or even checking for them because it is being dealt in a separate specialization.
This commit is contained in:
Hubert Plociniczak 2023-06-26 14:38:36 +02:00 committed by GitHub
parent c4f19e7d66
commit ae4666c4d3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 154 additions and 31 deletions

View File

@ -21,6 +21,7 @@ import org.enso.interpreter.runtime.callable.argument.CallArgumentInfo;
import org.enso.interpreter.runtime.callable.atom.Atom;
import org.enso.interpreter.runtime.callable.atom.AtomConstructor;
import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.runtime.control.TailCallException;
import org.enso.interpreter.runtime.error.DataflowError;
import org.enso.interpreter.runtime.error.PanicException;
import org.enso.interpreter.runtime.error.PanicSentinel;
@ -264,6 +265,15 @@ public abstract class InvokeCallableNode extends BaseNode {
State state,
Object[] arguments,
@CachedLibrary(limit = "3") WarningsLibrary warnings) {
Warning[] extracted;
Object callable;
try {
extracted = warnings.getWarnings(warning, null);
callable = warnings.removeWarnings(warning);
} catch (UnsupportedMessageException e) {
throw CompilerDirectives.shouldNotReachHere(e);
}
try {
if (childDispatch == null) {
CompilerDirectives.transferToInterpreterAndInvalidate();
@ -277,7 +287,7 @@ public abstract class InvokeCallableNode extends BaseNode {
invokeFunctionNode.getSchema(),
invokeFunctionNode.getDefaultsExecutionMode(),
invokeFunctionNode.getArgumentsExecutionMode()));
childDispatch.setTailStatus(TailStatus.NOT_TAIL);
childDispatch.setTailStatus(getTailStatus());
childDispatch.setId(invokeFunctionNode.getId());
notifyInserted(childDispatch);
}
@ -287,12 +297,12 @@ public abstract class InvokeCallableNode extends BaseNode {
}
var result = childDispatch.execute(
warnings.removeWarnings(warning),
callable,
callerFrame,
state,
arguments);
Warning[] extracted = warnings.getWarnings(warning, null);
if (result instanceof DataflowError) {
return result;
} else if (result instanceof WithWarnings withWarnings) {
@ -300,8 +310,8 @@ public abstract class InvokeCallableNode extends BaseNode {
} else {
return WithWarnings.wrap(EnsoContext.get(this), result, extracted);
}
} catch (UnsupportedMessageException e) {
throw CompilerDirectives.shouldNotReachHere(e);
} catch (TailCallException e) {
throw new TailCallException(e, extracted);
}
}

View File

@ -17,6 +17,7 @@ import org.enso.interpreter.runtime.EnsoContext;
import org.enso.interpreter.runtime.callable.UnresolvedConversion;
import org.enso.interpreter.runtime.callable.argument.CallArgumentInfo;
import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.runtime.control.TailCallException;
import org.enso.interpreter.runtime.data.ArrayRope;
import org.enso.interpreter.runtime.data.Type;
import org.enso.interpreter.runtime.data.text.Text;
@ -162,7 +163,7 @@ public abstract class InvokeConversionNode extends BaseNode {
invokeFunctionNode.getDefaultsExecutionMode(),
invokeFunctionNode.getArgumentsExecutionMode(),
thatArgumentPosition));
childDispatch.setTailStatus(TailStatus.NOT_TAIL);
childDispatch.setTailStatus(getTailStatus());
childDispatch.setId(invokeFunctionNode.getId());
notifyInserted(childDispatch);
}
@ -170,11 +171,15 @@ public abstract class InvokeConversionNode extends BaseNode {
lock.unlock();
}
}
arguments[thatArgumentPosition] = that.getValue();
Object value = that.getValue();
arguments[thatArgumentPosition] = value;
ArrayRope<Warning> warnings = that.getReassignedWarningsAsRope(this);
Object result =
childDispatch.execute(frame, state, conversion, self, that.getValue(), arguments);
try {
Object result = childDispatch.execute(frame, state, conversion, self, value, arguments);
return WithWarnings.appendTo(EnsoContext.get(this), result, warnings);
} catch (TailCallException e) {
throw new TailCallException(e, warnings.toArray(Warning[]::new));
}
}
@Specialization(guards = "interop.isString(that)")

View File

@ -40,6 +40,7 @@ import org.enso.interpreter.runtime.callable.argument.ArgumentDefinition;
import org.enso.interpreter.runtime.callable.argument.CallArgumentInfo;
import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.runtime.callable.function.FunctionSchema;
import org.enso.interpreter.runtime.control.TailCallException;
import org.enso.interpreter.runtime.data.ArrayRope;
import org.enso.interpreter.runtime.data.EnsoDate;
import org.enso.interpreter.runtime.data.EnsoDateTime;
@ -407,7 +408,7 @@ public abstract class InvokeMethodNode extends BaseNode {
invokeFunctionNode.getDefaultsExecutionMode(),
invokeFunctionNode.getArgumentsExecutionMode(),
thisArgumentPosition));
childDispatch.setTailStatus(TailStatus.NOT_TAIL);
childDispatch.setTailStatus(getTailStatus());
childDispatch.setId(invokeFunctionNode.getId());
notifyInserted(childDispatch);
}
@ -418,8 +419,12 @@ public abstract class InvokeMethodNode extends BaseNode {
arguments[thisArgumentPosition] = selfWithoutWarnings;
try {
Object result = childDispatch.execute(frame, state, symbol, selfWithoutWarnings, arguments);
return WithWarnings.appendTo(EnsoContext.get(this), result, arrOfWarnings);
} catch (TailCallException e) {
throw new TailCallException(e, arrOfWarnings);
}
}
@ExplodeLoop

View File

@ -5,6 +5,7 @@ import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.api.nodes.NodeInfo;
import org.enso.interpreter.runtime.callable.CallerInfo;
import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.runtime.error.Warning;
import org.enso.interpreter.runtime.state.State;
/**
@ -33,6 +34,7 @@ public abstract class CallOptimiserNode extends Node {
* @param callerInfo the caller info to pass to the function
* @param state the state to pass to the function
* @param arguments the arguments to {@code callable}
* @param warnings warnings associated with the callable, null if empty
* @return the result of executing {@code callable} using {@code arguments}
*/
public abstract Object executeDispatch(
@ -40,5 +42,6 @@ public abstract class CallOptimiserNode extends Node {
Function callable,
CallerInfo callerInfo,
State state,
Object[] arguments);
Object[] arguments,
Warning[] warnings);
}

View File

@ -133,7 +133,7 @@ public class CurryNode extends BaseNode {
return value;
}
} else {
var evaluatedVal = loopingCall.executeDispatch(frame, function, callerInfo, state, arguments);
var evaluatedVal = loopingCall.executeDispatch(frame, function, callerInfo, state, arguments, null);
return this.oversaturatedCallableNode.execute(
evaluatedVal, frame, state, oversaturatedArguments);
@ -154,7 +154,7 @@ public class CurryNode extends BaseNode {
return switch (getTailStatus()) {
case TAIL_DIRECT -> directCall.executeCall(frame, function, callerInfo, state, arguments);
case TAIL_LOOP -> throw new TailCallException(function, callerInfo, arguments);
default -> loopingCall.executeDispatch(frame, function, callerInfo, state, arguments);
default -> loopingCall.executeDispatch(frame, function, callerInfo, state, arguments, null);
};
}
}

View File

@ -92,7 +92,7 @@ public abstract class IndirectCurryNode extends Node {
return value;
}
} else {
var evaluatedVal = loopingCall.executeDispatch(frame, function, callerInfo, state, arguments);
var evaluatedVal = loopingCall.executeDispatch(frame, function, callerInfo, state, arguments, null);
return oversaturatedCallableNode.execute(
evaluatedVal,
@ -129,7 +129,7 @@ public abstract class IndirectCurryNode extends Node {
case TAIL_LOOP:
throw new TailCallException(function, callerInfo, arguments);
default:
return loopingCall.executeDispatch(frame, function, callerInfo, state, arguments);
return loopingCall.executeDispatch(frame, function, callerInfo, state, arguments, null);
}
}
}

View File

@ -15,9 +15,12 @@ import com.oracle.truffle.api.nodes.NodeInfo;
import com.oracle.truffle.api.nodes.RepeatingNode;
import org.enso.interpreter.node.callable.ExecuteCallNode;
import org.enso.interpreter.node.callable.ExecuteCallNodeGen;
import org.enso.interpreter.runtime.EnsoContext;
import org.enso.interpreter.runtime.callable.CallerInfo;
import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.runtime.control.TailCallException;
import org.enso.interpreter.runtime.error.Warning;
import org.enso.interpreter.runtime.error.WithWarnings;
import org.enso.interpreter.runtime.state.State;
/**
@ -54,23 +57,44 @@ public abstract class LoopingCallOptimiserNode extends CallOptimiserNode {
* @param loopNode a cached instance of the loop node used by this node
* @return the result of executing {@code function} using {@code arguments}
*/
@Specialization
public Object dispatch(
@Specialization(guards = "warnings == null")
public Object cachedDispatch(
Function function,
CallerInfo callerInfo,
State state,
Object[] arguments,
Warning[] warnings,
@Cached(value = "createLoopNode()") LoopNode loopNode) {
return dispatch(function, callerInfo, state, arguments, loopNode);
}
@Specialization(guards = "warnings != null")
public Object cachedDispatchWarnings(
Function function,
CallerInfo callerInfo,
State state,
Object[] arguments,
Warning[] warnings,
@Cached(value = "createLoopNode()") LoopNode loopNode) {
Object result = dispatch(function, callerInfo, state, arguments, loopNode);
return WithWarnings.appendTo(EnsoContext.get(this), result, warnings);
}
private Object dispatch(
Function function,
CallerInfo callerInfo,
State state,
Object[] arguments,
LoopNode loopNode) {
RepeatedCallNode repeatedCallNode = (RepeatedCallNode) loopNode.getRepeatingNode();
VirtualFrame frame = repeatedCallNode.createFrame();
repeatedCallNode.setNextCall(frame, function, callerInfo, arguments);
repeatedCallNode.setState(frame, state);
loopNode.execute(frame);
return repeatedCallNode.getResult(frame);
}
@Specialization(replaces = "dispatch")
@Specialization(replaces = "cachedDispatch", guards = "warnings == null")
@CompilerDirectives.TruffleBoundary
public Object uncachedDispatch(
MaterializedFrame frame,
@ -78,7 +102,33 @@ public abstract class LoopingCallOptimiserNode extends CallOptimiserNode {
CallerInfo callerInfo,
State state,
Object[] arguments,
Warning[] warnings,
@Cached ExecuteCallNode executeCallNode) {
return loopUntilCompletion(frame, function, callerInfo, state, arguments, executeCallNode);
}
@Specialization(replaces = "cachedDispatchWarnings", guards = "warnings != null")
@CompilerDirectives.TruffleBoundary
public Object uncachedDispatchWarnings(
MaterializedFrame frame,
Function function,
CallerInfo callerInfo,
State state,
Object[] arguments,
Warning[] warnings,
@Cached ExecuteCallNode executeCallNode) {
Object result =
loopUntilCompletion(frame, function, callerInfo, state, arguments, executeCallNode);
return WithWarnings.appendTo(EnsoContext.get(this), result, warnings);
}
private Object loopUntilCompletion(
MaterializedFrame frame,
Function function,
CallerInfo callerInfo,
State state,
Object[] arguments,
ExecuteCallNode executeCallNode) {
while (true) {
try {
return executeCallNode.executeCall(frame, function, callerInfo, state, arguments);

View File

@ -9,6 +9,7 @@ import org.enso.interpreter.node.callable.ExecuteCallNodeGen;
import org.enso.interpreter.runtime.callable.CallerInfo;
import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.runtime.control.TailCallException;
import org.enso.interpreter.runtime.error.Warning;
import org.enso.interpreter.runtime.state.State;
/**
@ -40,6 +41,7 @@ public class SimpleCallOptimiserNode extends CallOptimiserNode {
* @param callerInfo the caller info to pass to the function
* @param state the state to pass to the function
* @param arguments the arguments to {@code function}
* @param warnings warnings associated with the callable, null if empty
* @return the result of executing {@code function} using {@code arguments}
*/
@Override
@ -48,7 +50,8 @@ public class SimpleCallOptimiserNode extends CallOptimiserNode {
Function function,
CallerInfo callerInfo,
State state,
Object[] arguments) {
Object[] arguments,
Warning[] warnings) {
try {
return executeCallNode.executeCall(frame, function, callerInfo, state, arguments);
} catch (TailCallException e) {
@ -65,7 +68,7 @@ public class SimpleCallOptimiserNode extends CallOptimiserNode {
}
}
return next.executeDispatch(
frame, e.getFunction(), e.getCallerInfo(), state, e.getArguments());
frame, e.getFunction(), e.getCallerInfo(), state, e.getArguments(), e.getWarnings());
}
}
}

View File

@ -67,7 +67,7 @@ public abstract class ThunkExecutorNode extends Node {
return callNode.call(Function.ArgumentsHelper.buildArguments(function, state));
} catch (TailCallException e) {
return loopingCallOptimiserNode.executeDispatch(
frame, e.getFunction(), e.getCallerInfo(), state, e.getArguments());
frame, e.getFunction(), e.getCallerInfo(), state, e.getArguments(), e.getWarnings());
}
}
}
@ -89,7 +89,7 @@ public abstract class ThunkExecutorNode extends Node {
function.getCallTarget(), Function.ArgumentsHelper.buildArguments(function, state));
} catch (TailCallException e) {
return loopingCallOptimiserNode.executeDispatch(
frame, e.getFunction(), e.getCallerInfo(), state, e.getArguments());
frame, e.getFunction(), e.getCallerInfo(), state, e.getArguments(), e.getWarnings());
}
}
}

View File

@ -811,8 +811,8 @@ public abstract class SortVectorNode extends Node {
Object yConverted;
if (hasCustomOnFunc) {
// onFunc cannot have `self` argument, we assume it has just one argument.
xConverted = callNode.executeDispatch(null, onFunc.get(x), null, state, new Object[]{x});
yConverted = callNode.executeDispatch(null, onFunc.get(y), null, state, new Object[]{y});
xConverted = callNode.executeDispatch(null, onFunc.get(x), null, state, new Object[]{x}, null);
yConverted = callNode.executeDispatch(null, onFunc.get(y), null, state, new Object[]{y}, null);
} else {
xConverted = x;
yConverted = y;
@ -823,7 +823,7 @@ public abstract class SortVectorNode extends Node {
} else {
args = new Object[] {xConverted, yConverted};
}
Object res = callNode.executeDispatch(null, compareFunc.get(xConverted), null, state, args);
Object res = callNode.executeDispatch(null, compareFunc.get(xConverted), null, state, args, null);
if (res == less) {
return ascending ? -1 : 1;
} else if (res == equal) {

View File

@ -636,7 +636,8 @@ public final class Module implements TruffleObject {
eval.getFunction(),
callerInfo,
context.emptyState(),
new Object[] {builtins.debug(), Text.create(expr)});
new Object[] {builtins.debug(), Text.create(expr)},
null);
}
private static Object generateDocs(Module module, EnsoContext context) {

View File

@ -3,6 +3,7 @@ package org.enso.interpreter.runtime.control;
import com.oracle.truffle.api.nodes.ControlFlowException;
import org.enso.interpreter.runtime.callable.CallerInfo;
import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.runtime.error.Warning;
/**
* Used to model the switch of control-flow from standard stack-based execution to looping.
@ -13,18 +14,38 @@ public class TailCallException extends ControlFlowException {
private final Function function;
private final CallerInfo callerInfo;
private final Object[] arguments;
private final Warning[] warnings;
/**
* Creates a new exception containing the necessary data to continue computation.
*
* @param function the function to execute in a loop
* @param state the state to pass to the function
* @param callerInfo the caller execution context
* @param arguments the arguments to {@code function}
*/
public TailCallException(Function function, CallerInfo callerInfo, Object[] arguments) {
this.function = function;
this.callerInfo = callerInfo;
this.arguments = arguments;
this.warnings = null;
}
private TailCallException(
Function function, CallerInfo callerInfo, Object[] arguments, Warning[] warnings) {
this.function = function;
this.callerInfo = callerInfo;
this.arguments = arguments;
this.warnings = warnings;
}
/**
* Creates a new exception containing the necessary data to continue computation.
*
* @param origin the original tail call exception
* @param warnings warnings to be associated with the tail call exception
*/
public TailCallException(TailCallException origin, Warning[] warnings) {
this(origin.getFunction(), origin.getCallerInfo(), origin.getArguments(), warnings);
}
/**
@ -53,4 +74,13 @@ public class TailCallException extends ControlFlowException {
public CallerInfo getCallerInfo() {
return callerInfo;
}
/**
* Gets the warnings that should be appended to the result of calling the function.
*
* @return the warnings to be appended to the result of the call, or null if empty
*/
public Warning[] getWarnings() {
return warnings;
}
}

View File

@ -437,4 +437,20 @@ spec = Test.group "Dataflow Warnings" <|
result_non_tail . should_equal 6
Warning.get_all result_non_tail . map .value . should_equal ["Foo"]
Test.specify "should not break TCO when warnings are attached to arguments" <|
vec = Vector.new 10000 (i-> i+1)
elem1 = Warning.attach "WARNING1" 998
vec.contains 998 . should_equal True
res1 = vec.contains elem1
res1 . should_be_true
Warning.get_all res1 . map .value . should_equal ["WARNING1"]
elem2 = Warning.attach "WARNING2" 9988
vec.contains 9988 . should_be_true
vec.contains elem2 . should_be_true
res2 = vec.contains elem2
res2 . should_equal True
Warning.get_all res2 . map .value . should_equal ["WARNING2"]
main = Test_Suite.run_main spec