UnresolvedSymbol is now accepted by Vector.sort (#6334)

`Vector.sort` does some custom method dispatch logic which always expected a function as `by` and `on` arguments. At the same time, `UnresolvedSymbol` is treated like a (to be resolved) `Function` and under normal circumstances there would be no difference between `_.foo` and `.foo` provided as arguments.

Rather than adding an additional phase that does some form of eta-expansion, to accomodate for this custom dispatch, this change only fixes the problem locally. We accept `Function` and `UnresolvedSymbol` and perform the resolution on the fly. Ideally, we would have a specialization on the latter but again, it would be dependent on the contents of the `Vector` so unclear if that is better.

Closes #6276,

# Important Notes
There was a suggestion to somehow modify our codegen to accomodate for this scenario but I went against it. In fact a lot of name literals have `isMethod` flag and that information is used in the passes but it should not control how (late) codegen is done. If we were to make this more generic, I would suggest maybe to add separate eta-expansion pass. But it could affect other things and could be potentially a significant change with limited potential initially, so potential future work item.
This commit is contained in:
Hubert Plociniczak 2023-04-20 09:58:58 +02:00 committed by GitHub
parent dd4dce2c3f
commit 6d3151f32d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 131 additions and 25 deletions

View File

@ -22,11 +22,13 @@ import java.util.stream.Collectors;
import org.enso.interpreter.dsl.AcceptsError;
import org.enso.interpreter.dsl.BuiltinMethod;
import org.enso.interpreter.node.callable.dispatch.CallOptimiserNode;
import org.enso.interpreter.node.callable.resolver.MethodResolverNode;
import org.enso.interpreter.node.expression.builtin.interop.syntax.HostValueToEnsoNode;
import org.enso.interpreter.node.expression.builtin.meta.EqualsNode;
import org.enso.interpreter.node.expression.builtin.meta.TypeOfNode;
import org.enso.interpreter.node.expression.builtin.text.AnyToTextNode;
import org.enso.interpreter.runtime.EnsoContext;
import org.enso.interpreter.runtime.callable.UnresolvedSymbol;
import org.enso.interpreter.runtime.callable.atom.Atom;
import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.runtime.data.Array;
@ -39,6 +41,7 @@ import org.enso.interpreter.runtime.error.PanicException;
import org.enso.interpreter.runtime.error.Warning;
import org.enso.interpreter.runtime.error.WarningsLibrary;
import org.enso.interpreter.runtime.error.WithWarnings;
import org.enso.interpreter.runtime.library.dispatch.TypesLibrary;
import org.enso.interpreter.runtime.state.State;
/**
@ -179,10 +182,12 @@ public abstract class SortVectorNode extends Node {
long problemBehaviorNum,
@CachedLibrary(limit = "10") InteropLibrary interop,
@CachedLibrary(limit = "5") WarningsLibrary warningsLib,
@CachedLibrary(limit = "5") TypesLibrary typesLib,
@Cached LessThanNode lessThanNode,
@Cached EqualsNode equalsNode,
@Cached TypeOfNode typeOfNode,
@Cached AnyToTextNode toTextNode,
@Cached MethodResolverNode methodResolverNode,
@Cached(value = "build()", uncached = "build()") HostValueToEnsoNode hostValueToEnsoNode,
@Cached(value = "build()", uncached = "build()") CallOptimiserNode callNode) {
var problemBehavior = ProblemBehavior.fromInt((int) problemBehaviorNum);
@ -230,7 +235,9 @@ public abstract class SortVectorNode extends Node {
less,
equal,
greater,
interop);
interop,
typesLib,
methodResolverNode);
}
group.elems.sort(javaComparator);
if (javaComparator.hasWarnings()) {
@ -647,6 +654,94 @@ public abstract class SortVectorNode extends Node {
}
}
/**
* Helper class that returns the comparator function.
*
* The class is introduced to handle the presence of {@code UnresolvedSymbol},
* as the comparator function, which has to be first resolved before it
* can be used to compare values.
*/
private abstract class Compare {
/**
* Test if the comparator function has self argument.
*
* @param definedOn the value on which the function is defined on.
* @return true if self argument is present, false otherwise.
*/
abstract boolean hasFunctionSelfArgument(Object definedOn);
/**
* Return a comparator function.
*
* @param arg the value on which the function is defined on.
* @return a non-null comparator function.
*/
abstract Function get(Object arg);
}
private final class CompareFromFunction extends Compare {
private final Function function;
private CompareFromFunction(Function function) {
this.function = function;
}
@Override
boolean hasFunctionSelfArgument(Object definedOn) {
if (function.getSchema().getArgumentsCount() > 0) {
return function.getSchema().getArgumentInfos()[0].getName().equals("self");
} else {
return false;
}
}
@Override
Function get(Object arg) {
return function;
}
}
private class CompareFromUnresolvedSymbol extends Compare {
private final UnresolvedSymbol unresolvedSymbol;
private final MethodResolverNode methodResolverNode;
private final TypesLibrary typesLibrary;
private @CompilerDirectives.CompilationFinal Function resolvedFunction;
private CompareFromUnresolvedSymbol(UnresolvedSymbol unresolvedSymbol,
MethodResolverNode methodResolvedNode,
TypesLibrary typesLibrary) {
this.unresolvedSymbol = unresolvedSymbol;
this.methodResolverNode = methodResolvedNode;
this.typesLibrary = typesLibrary;
}
@Override
boolean hasFunctionSelfArgument(Object definedOn) {
ensureSymbolIsResolved(definedOn);
return resolvedFunction.getSchema().getArgumentsCount() > 0 &&
resolvedFunction.getSchema().getArgumentInfos()[0].getName().equals("self");
}
@Override
Function get(Object arg) {
ensureSymbolIsResolved(arg);
return resolvedFunction;
}
private void ensureSymbolIsResolved(Object definedOn) {
if (resolvedFunction == null) {
resolvedFunction = methodResolverNode.expectNonNull(definedOn, typesLibrary.getType(definedOn), unresolvedSymbol);
}
}
}
/**
* Comparator for any values. This comparator compares the values by calling back to Enso (by
* {@link #compareFunc}), rather than using compare nodes (i.e. {@link LessThanNode}). directly,
@ -659,9 +754,9 @@ public abstract class SortVectorNode extends Node {
* Either function from `by` parameter to the `Vector.sort` method, or the `compare` function
* extracted from the comparator for the appropriate group.
*/
private final Function compareFunc;
private final Compare compareFunc;
private final Function onFunc;
private final Compare onFunc;
private final boolean hasCustomOnFunc;
private final Type comparator;
private final CallOptimiserNode callNode;
@ -682,20 +777,22 @@ public abstract class SortVectorNode extends Node {
Atom less,
Atom equal,
Atom greater,
InteropLibrary interop) {
InteropLibrary interop,
TypesLibrary typesLibrary,
MethodResolverNode methodResolverNode) {
super(toTextNode, problemBehavior, interop);
assert compareFunc != null;
assert comparator != null;
this.comparator = comparator;
this.state = state;
this.ascending = ascending;
this.compareFunc = checkAndConvertByFunc(compareFunc);
this.compareFunc = checkAndConvertByFunc(compareFunc, typesLibrary, methodResolverNode);
if (interop.isNull(onFunc)) {
this.hasCustomOnFunc = false;
this.onFunc = null;
} else {
this.hasCustomOnFunc = true;
this.onFunc = checkAndConvertOnFunc(onFunc);
this.onFunc = checkAndConvertOnFunc(onFunc, typesLibrary, methodResolverNode);
}
this.callNode = callNode;
this.less = less;
@ -709,19 +806,19 @@ public abstract class SortVectorNode extends Node {
Object yConverted;
if (hasCustomOnFunc) {
// onFunc cannot have `self` argument, we assume it has just one argument.
xConverted = callNode.executeDispatch(onFunc, null, state, new Object[] {x});
yConverted = callNode.executeDispatch(onFunc, null, state, new Object[] {y});
xConverted = callNode.executeDispatch(onFunc.get(x), null, state, new Object[]{x});
yConverted = callNode.executeDispatch(onFunc.get(y), null, state, new Object[]{y});
} else {
xConverted = x;
yConverted = y;
}
Object[] args;
if (hasFunctionSelfArgument(compareFunc)) {
if (compareFunc.hasFunctionSelfArgument(xConverted)) {
args = new Object[] {comparator, xConverted, yConverted};
} else {
args = new Object[] {xConverted, yConverted};
}
Object res = callNode.executeDispatch(compareFunc, null, state, args);
Object res = callNode.executeDispatch(compareFunc.get(xConverted), null, state, args);
if (res == less) {
return ascending ? -1 : 1;
} else if (res == equal) {
@ -738,43 +835,43 @@ public abstract class SortVectorNode extends Node {
}
}
private boolean hasFunctionSelfArgument(Function function) {
if (function.getSchema().getArgumentsCount() > 0) {
return function.getSchema().getArgumentInfos()[0].getName().equals("self");
} else {
return false;
}
}
/**
* Checks value given for {@code by} parameter and converts it to {@link Function}. Throw a
* dataflow error otherwise.
*/
private Function checkAndConvertByFunc(Object byFuncObj) {
private Compare checkAndConvertByFunc(Object byFuncObj, TypesLibrary typesLibrary, MethodResolverNode methodResolverNode) {
return checkAndConvertFunction(
byFuncObj, "Unsupported argument for `by`, expected a method with two arguments", 2, 3);
byFuncObj, "Unsupported argument for `by`, expected a method with two arguments", 2, 3,
typesLibrary, methodResolverNode);
}
/**
* Checks the value given for {@code on} parameter and converts it to {@link Function}. Throws a
* dataflow error otherwise.
*/
private Function checkAndConvertOnFunc(Object onFuncObj) {
private Compare checkAndConvertOnFunc(Object onFuncObj, TypesLibrary typesLibrary, MethodResolverNode methodResolverNode) {
return checkAndConvertFunction(
onFuncObj, "Unsupported argument for `on`, expected a method with one argument", 1, 1);
onFuncObj, "Unsupported argument for `on`, expected a method with one argument", 1, 1,
typesLibrary, methodResolverNode);
}
/**
* @param minArgCount Minimal count of arguments without a default value.
* @param maxArgCount Maximal count of argument without a default value.
* @param methodResolverNode node for resolving unresolved symbols.
* @param typesLibrary types library for resolving the dispatch type for unresolved symbols.
*/
private Function checkAndConvertFunction(
Object funcObj, String errMsg, int minArgCount, int maxArgCount) {
private Compare checkAndConvertFunction(
Object funcObj, String errMsg, int minArgCount, int maxArgCount,
TypesLibrary typesLibrary, MethodResolverNode methodResolverNode) {
if (funcObj instanceof UnresolvedSymbol unresolved) {
return new CompareFromUnresolvedSymbol(unresolved, methodResolverNode, typesLibrary);
}
var err = new IllegalArgumentException(errMsg + ", got " + funcObj);
if (funcObj instanceof Function func) {
var argCount = getNumberOfNonDefaultArguments(func);
if (minArgCount <= argCount && argCount <= maxArgCount) {
return func;
return new CompareFromFunction(func);
} else {
throw err;
}

View File

@ -49,6 +49,11 @@ make_partially_sorted_vec n =
run_length.put (run_length.get - 1)
num
type Int
Value v
identity self = self
# The Benchmarks ==============================================================
@ -56,6 +61,7 @@ bench =
sorted_vec = make_sorted_ascending_vec vector_size
partially_sorted_vec = make_partially_sorted_vec vector_size
random_vec = Utils.make_random_vec vector_size
random_vec_wrapped = random_vec.map (v -> Int.Value v)
projection = x -> x % 10
comparator = l -> r -> Ordering.compare l r
@ -66,6 +72,8 @@ bench =
Bench.measure (random_vec.sort) "Random Elements Ascending" iter_size num_iterations
Bench.measure (random_vec.sort Sort_Direction.Descending) "Random Elements Descending" iter_size num_iterations
Bench.measure (random_vec.sort on=projection) "Sorting with a Custom Projection" iter_size num_iterations
Bench.measure (random_vec_wrapped.sort on=(_.identity)) "Sorting with an identity function" iter_size num_iterations
Bench.measure (random_vec_wrapped.sort on=(.identity)) "Sorting with an (unresolved) identity function" iter_size num_iterations
Bench.measure (random_vec.sort by=comparator) "Sorting with the Default_Ordered_Comparator" iter_size num_iterations
main = bench

View File

@ -579,6 +579,7 @@ type_spec name alter = Test.group name <|
small_vec = alter [T.Value 1 8, T.Value 1 3, T.Value -20 0, T.Value -1 1, T.Value -1 10, T.Value 4 0]
small_expected = [T.Value -20 0, T.Value 4 0, T.Value -1 1, T.Value 1 3, T.Value 1 8, T.Value -1 10]
small_vec.sort (on = _.b) . should_equal small_expected
small_vec.sort (on = .b) . should_equal small_expected
Test.specify "should be able to use a custom compare function" <|
small_vec = alter [2, 7, -3, 383, -392, 28, -90]