mirror of
https://github.com/enso-org/enso.git
synced 2024-12-23 13:02:07 +03:00
UnresolvedSymbol is now accepted by Vector.sort (#6334)
`Vector.sort` does some custom method dispatch logic which always expected a function as `by` and `on` arguments. At the same time, `UnresolvedSymbol` is treated like a (to be resolved) `Function` and under normal circumstances there would be no difference between `_.foo` and `.foo` provided as arguments. Rather than adding an additional phase that does some form of eta-expansion, to accomodate for this custom dispatch, this change only fixes the problem locally. We accept `Function` and `UnresolvedSymbol` and perform the resolution on the fly. Ideally, we would have a specialization on the latter but again, it would be dependent on the contents of the `Vector` so unclear if that is better. Closes #6276, # Important Notes There was a suggestion to somehow modify our codegen to accomodate for this scenario but I went against it. In fact a lot of name literals have `isMethod` flag and that information is used in the passes but it should not control how (late) codegen is done. If we were to make this more generic, I would suggest maybe to add separate eta-expansion pass. But it could affect other things and could be potentially a significant change with limited potential initially, so potential future work item.
This commit is contained in:
parent
dd4dce2c3f
commit
6d3151f32d
@ -22,11 +22,13 @@ import java.util.stream.Collectors;
|
||||
import org.enso.interpreter.dsl.AcceptsError;
|
||||
import org.enso.interpreter.dsl.BuiltinMethod;
|
||||
import org.enso.interpreter.node.callable.dispatch.CallOptimiserNode;
|
||||
import org.enso.interpreter.node.callable.resolver.MethodResolverNode;
|
||||
import org.enso.interpreter.node.expression.builtin.interop.syntax.HostValueToEnsoNode;
|
||||
import org.enso.interpreter.node.expression.builtin.meta.EqualsNode;
|
||||
import org.enso.interpreter.node.expression.builtin.meta.TypeOfNode;
|
||||
import org.enso.interpreter.node.expression.builtin.text.AnyToTextNode;
|
||||
import org.enso.interpreter.runtime.EnsoContext;
|
||||
import org.enso.interpreter.runtime.callable.UnresolvedSymbol;
|
||||
import org.enso.interpreter.runtime.callable.atom.Atom;
|
||||
import org.enso.interpreter.runtime.callable.function.Function;
|
||||
import org.enso.interpreter.runtime.data.Array;
|
||||
@ -39,6 +41,7 @@ import org.enso.interpreter.runtime.error.PanicException;
|
||||
import org.enso.interpreter.runtime.error.Warning;
|
||||
import org.enso.interpreter.runtime.error.WarningsLibrary;
|
||||
import org.enso.interpreter.runtime.error.WithWarnings;
|
||||
import org.enso.interpreter.runtime.library.dispatch.TypesLibrary;
|
||||
import org.enso.interpreter.runtime.state.State;
|
||||
|
||||
/**
|
||||
@ -179,10 +182,12 @@ public abstract class SortVectorNode extends Node {
|
||||
long problemBehaviorNum,
|
||||
@CachedLibrary(limit = "10") InteropLibrary interop,
|
||||
@CachedLibrary(limit = "5") WarningsLibrary warningsLib,
|
||||
@CachedLibrary(limit = "5") TypesLibrary typesLib,
|
||||
@Cached LessThanNode lessThanNode,
|
||||
@Cached EqualsNode equalsNode,
|
||||
@Cached TypeOfNode typeOfNode,
|
||||
@Cached AnyToTextNode toTextNode,
|
||||
@Cached MethodResolverNode methodResolverNode,
|
||||
@Cached(value = "build()", uncached = "build()") HostValueToEnsoNode hostValueToEnsoNode,
|
||||
@Cached(value = "build()", uncached = "build()") CallOptimiserNode callNode) {
|
||||
var problemBehavior = ProblemBehavior.fromInt((int) problemBehaviorNum);
|
||||
@ -230,7 +235,9 @@ public abstract class SortVectorNode extends Node {
|
||||
less,
|
||||
equal,
|
||||
greater,
|
||||
interop);
|
||||
interop,
|
||||
typesLib,
|
||||
methodResolverNode);
|
||||
}
|
||||
group.elems.sort(javaComparator);
|
||||
if (javaComparator.hasWarnings()) {
|
||||
@ -647,6 +654,94 @@ public abstract class SortVectorNode extends Node {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper class that returns the comparator function.
|
||||
*
|
||||
* The class is introduced to handle the presence of {@code UnresolvedSymbol},
|
||||
* as the comparator function, which has to be first resolved before it
|
||||
* can be used to compare values.
|
||||
*/
|
||||
private abstract class Compare {
|
||||
|
||||
/**
|
||||
* Test if the comparator function has self argument.
|
||||
*
|
||||
* @param definedOn the value on which the function is defined on.
|
||||
* @return true if self argument is present, false otherwise.
|
||||
*/
|
||||
abstract boolean hasFunctionSelfArgument(Object definedOn);
|
||||
|
||||
/**
|
||||
* Return a comparator function.
|
||||
*
|
||||
* @param arg the value on which the function is defined on.
|
||||
* @return a non-null comparator function.
|
||||
*/
|
||||
abstract Function get(Object arg);
|
||||
|
||||
}
|
||||
|
||||
private final class CompareFromFunction extends Compare {
|
||||
|
||||
private final Function function;
|
||||
|
||||
private CompareFromFunction(Function function) {
|
||||
this.function = function;
|
||||
}
|
||||
|
||||
@Override
|
||||
boolean hasFunctionSelfArgument(Object definedOn) {
|
||||
if (function.getSchema().getArgumentsCount() > 0) {
|
||||
return function.getSchema().getArgumentInfos()[0].getName().equals("self");
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
Function get(Object arg) {
|
||||
return function;
|
||||
}
|
||||
}
|
||||
|
||||
private class CompareFromUnresolvedSymbol extends Compare {
|
||||
|
||||
private final UnresolvedSymbol unresolvedSymbol;
|
||||
private final MethodResolverNode methodResolverNode;
|
||||
private final TypesLibrary typesLibrary;
|
||||
|
||||
private @CompilerDirectives.CompilationFinal Function resolvedFunction;
|
||||
|
||||
private CompareFromUnresolvedSymbol(UnresolvedSymbol unresolvedSymbol,
|
||||
MethodResolverNode methodResolvedNode,
|
||||
TypesLibrary typesLibrary) {
|
||||
this.unresolvedSymbol = unresolvedSymbol;
|
||||
this.methodResolverNode = methodResolvedNode;
|
||||
this.typesLibrary = typesLibrary;
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
boolean hasFunctionSelfArgument(Object definedOn) {
|
||||
ensureSymbolIsResolved(definedOn);
|
||||
return resolvedFunction.getSchema().getArgumentsCount() > 0 &&
|
||||
resolvedFunction.getSchema().getArgumentInfos()[0].getName().equals("self");
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
Function get(Object arg) {
|
||||
ensureSymbolIsResolved(arg);
|
||||
return resolvedFunction;
|
||||
}
|
||||
|
||||
private void ensureSymbolIsResolved(Object definedOn) {
|
||||
if (resolvedFunction == null) {
|
||||
resolvedFunction = methodResolverNode.expectNonNull(definedOn, typesLibrary.getType(definedOn), unresolvedSymbol);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Comparator for any values. This comparator compares the values by calling back to Enso (by
|
||||
* {@link #compareFunc}), rather than using compare nodes (i.e. {@link LessThanNode}). directly,
|
||||
@ -659,9 +754,9 @@ public abstract class SortVectorNode extends Node {
|
||||
* Either function from `by` parameter to the `Vector.sort` method, or the `compare` function
|
||||
* extracted from the comparator for the appropriate group.
|
||||
*/
|
||||
private final Function compareFunc;
|
||||
private final Compare compareFunc;
|
||||
|
||||
private final Function onFunc;
|
||||
private final Compare onFunc;
|
||||
private final boolean hasCustomOnFunc;
|
||||
private final Type comparator;
|
||||
private final CallOptimiserNode callNode;
|
||||
@ -682,20 +777,22 @@ public abstract class SortVectorNode extends Node {
|
||||
Atom less,
|
||||
Atom equal,
|
||||
Atom greater,
|
||||
InteropLibrary interop) {
|
||||
InteropLibrary interop,
|
||||
TypesLibrary typesLibrary,
|
||||
MethodResolverNode methodResolverNode) {
|
||||
super(toTextNode, problemBehavior, interop);
|
||||
assert compareFunc != null;
|
||||
assert comparator != null;
|
||||
this.comparator = comparator;
|
||||
this.state = state;
|
||||
this.ascending = ascending;
|
||||
this.compareFunc = checkAndConvertByFunc(compareFunc);
|
||||
this.compareFunc = checkAndConvertByFunc(compareFunc, typesLibrary, methodResolverNode);
|
||||
if (interop.isNull(onFunc)) {
|
||||
this.hasCustomOnFunc = false;
|
||||
this.onFunc = null;
|
||||
} else {
|
||||
this.hasCustomOnFunc = true;
|
||||
this.onFunc = checkAndConvertOnFunc(onFunc);
|
||||
this.onFunc = checkAndConvertOnFunc(onFunc, typesLibrary, methodResolverNode);
|
||||
}
|
||||
this.callNode = callNode;
|
||||
this.less = less;
|
||||
@ -709,19 +806,19 @@ public abstract class SortVectorNode extends Node {
|
||||
Object yConverted;
|
||||
if (hasCustomOnFunc) {
|
||||
// onFunc cannot have `self` argument, we assume it has just one argument.
|
||||
xConverted = callNode.executeDispatch(onFunc, null, state, new Object[] {x});
|
||||
yConverted = callNode.executeDispatch(onFunc, null, state, new Object[] {y});
|
||||
xConverted = callNode.executeDispatch(onFunc.get(x), null, state, new Object[]{x});
|
||||
yConverted = callNode.executeDispatch(onFunc.get(y), null, state, new Object[]{y});
|
||||
} else {
|
||||
xConverted = x;
|
||||
yConverted = y;
|
||||
}
|
||||
Object[] args;
|
||||
if (hasFunctionSelfArgument(compareFunc)) {
|
||||
if (compareFunc.hasFunctionSelfArgument(xConverted)) {
|
||||
args = new Object[] {comparator, xConverted, yConverted};
|
||||
} else {
|
||||
args = new Object[] {xConverted, yConverted};
|
||||
}
|
||||
Object res = callNode.executeDispatch(compareFunc, null, state, args);
|
||||
Object res = callNode.executeDispatch(compareFunc.get(xConverted), null, state, args);
|
||||
if (res == less) {
|
||||
return ascending ? -1 : 1;
|
||||
} else if (res == equal) {
|
||||
@ -738,43 +835,43 @@ public abstract class SortVectorNode extends Node {
|
||||
}
|
||||
}
|
||||
|
||||
private boolean hasFunctionSelfArgument(Function function) {
|
||||
if (function.getSchema().getArgumentsCount() > 0) {
|
||||
return function.getSchema().getArgumentInfos()[0].getName().equals("self");
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks value given for {@code by} parameter and converts it to {@link Function}. Throw a
|
||||
* dataflow error otherwise.
|
||||
*/
|
||||
private Function checkAndConvertByFunc(Object byFuncObj) {
|
||||
private Compare checkAndConvertByFunc(Object byFuncObj, TypesLibrary typesLibrary, MethodResolverNode methodResolverNode) {
|
||||
return checkAndConvertFunction(
|
||||
byFuncObj, "Unsupported argument for `by`, expected a method with two arguments", 2, 3);
|
||||
byFuncObj, "Unsupported argument for `by`, expected a method with two arguments", 2, 3,
|
||||
typesLibrary, methodResolverNode);
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks the value given for {@code on} parameter and converts it to {@link Function}. Throws a
|
||||
* dataflow error otherwise.
|
||||
*/
|
||||
private Function checkAndConvertOnFunc(Object onFuncObj) {
|
||||
private Compare checkAndConvertOnFunc(Object onFuncObj, TypesLibrary typesLibrary, MethodResolverNode methodResolverNode) {
|
||||
return checkAndConvertFunction(
|
||||
onFuncObj, "Unsupported argument for `on`, expected a method with one argument", 1, 1);
|
||||
onFuncObj, "Unsupported argument for `on`, expected a method with one argument", 1, 1,
|
||||
typesLibrary, methodResolverNode);
|
||||
}
|
||||
|
||||
/**
|
||||
* @param minArgCount Minimal count of arguments without a default value.
|
||||
* @param maxArgCount Maximal count of argument without a default value.
|
||||
* @param methodResolverNode node for resolving unresolved symbols.
|
||||
* @param typesLibrary types library for resolving the dispatch type for unresolved symbols.
|
||||
*/
|
||||
private Function checkAndConvertFunction(
|
||||
Object funcObj, String errMsg, int minArgCount, int maxArgCount) {
|
||||
private Compare checkAndConvertFunction(
|
||||
Object funcObj, String errMsg, int minArgCount, int maxArgCount,
|
||||
TypesLibrary typesLibrary, MethodResolverNode methodResolverNode) {
|
||||
if (funcObj instanceof UnresolvedSymbol unresolved) {
|
||||
return new CompareFromUnresolvedSymbol(unresolved, methodResolverNode, typesLibrary);
|
||||
}
|
||||
var err = new IllegalArgumentException(errMsg + ", got " + funcObj);
|
||||
if (funcObj instanceof Function func) {
|
||||
var argCount = getNumberOfNonDefaultArguments(func);
|
||||
if (minArgCount <= argCount && argCount <= maxArgCount) {
|
||||
return func;
|
||||
return new CompareFromFunction(func);
|
||||
} else {
|
||||
throw err;
|
||||
}
|
||||
|
@ -49,6 +49,11 @@ make_partially_sorted_vec n =
|
||||
run_length.put (run_length.get - 1)
|
||||
num
|
||||
|
||||
type Int
|
||||
Value v
|
||||
|
||||
identity self = self
|
||||
|
||||
|
||||
# The Benchmarks ==============================================================
|
||||
|
||||
@ -56,6 +61,7 @@ bench =
|
||||
sorted_vec = make_sorted_ascending_vec vector_size
|
||||
partially_sorted_vec = make_partially_sorted_vec vector_size
|
||||
random_vec = Utils.make_random_vec vector_size
|
||||
random_vec_wrapped = random_vec.map (v -> Int.Value v)
|
||||
projection = x -> x % 10
|
||||
comparator = l -> r -> Ordering.compare l r
|
||||
|
||||
@ -66,6 +72,8 @@ bench =
|
||||
Bench.measure (random_vec.sort) "Random Elements Ascending" iter_size num_iterations
|
||||
Bench.measure (random_vec.sort Sort_Direction.Descending) "Random Elements Descending" iter_size num_iterations
|
||||
Bench.measure (random_vec.sort on=projection) "Sorting with a Custom Projection" iter_size num_iterations
|
||||
Bench.measure (random_vec_wrapped.sort on=(_.identity)) "Sorting with an identity function" iter_size num_iterations
|
||||
Bench.measure (random_vec_wrapped.sort on=(.identity)) "Sorting with an (unresolved) identity function" iter_size num_iterations
|
||||
Bench.measure (random_vec.sort by=comparator) "Sorting with the Default_Ordered_Comparator" iter_size num_iterations
|
||||
|
||||
main = bench
|
||||
|
@ -579,6 +579,7 @@ type_spec name alter = Test.group name <|
|
||||
small_vec = alter [T.Value 1 8, T.Value 1 3, T.Value -20 0, T.Value -1 1, T.Value -1 10, T.Value 4 0]
|
||||
small_expected = [T.Value -20 0, T.Value 4 0, T.Value -1 1, T.Value 1 3, T.Value 1 8, T.Value -1 10]
|
||||
small_vec.sort (on = _.b) . should_equal small_expected
|
||||
small_vec.sort (on = .b) . should_equal small_expected
|
||||
|
||||
Test.specify "should be able to use a custom compare function" <|
|
||||
small_vec = alter [2, 7, -3, 383, -392, 28, -90]
|
||||
|
Loading…
Reference in New Issue
Block a user