mirror of
https://github.com/enso-org/enso.git
synced 2024-12-23 23:31:35 +03:00
HashMapNode supports atoms with custom comparators (#7165)
Add proper handling for atoms with custom comparators into the hashing machinery.
This commit is contained in:
parent
cb9d4c4607
commit
ebee8700ce
@ -46,7 +46,7 @@ public abstract class EqualsAtomNode extends Node {
|
||||
},
|
||||
limit = "10")
|
||||
@ExplodeLoop
|
||||
boolean equalsAtoms(
|
||||
boolean equalsAtomsWithDefaultComparator(
|
||||
Atom self,
|
||||
Atom other,
|
||||
@Cached("self.getConstructor()") AtomConstructor selfCtorCached,
|
||||
@ -81,7 +81,7 @@ public abstract class EqualsAtomNode extends Node {
|
||||
"cachedComparator != null",
|
||||
},
|
||||
limit = "10")
|
||||
boolean equalsAtoms(
|
||||
boolean equalsAtomsWithCustomComparator(
|
||||
Atom self,
|
||||
Atom other,
|
||||
@Cached("self.getConstructor()") AtomConstructor selfCtorCached,
|
||||
@ -101,11 +101,25 @@ public abstract class EqualsAtomNode extends Node {
|
||||
}
|
||||
|
||||
@CompilerDirectives.TruffleBoundary
|
||||
@Specialization(replaces = "equalsAtoms")
|
||||
@Specialization(
|
||||
replaces = {"equalsAtomsWithDefaultComparator", "equalsAtomsWithCustomComparator"})
|
||||
boolean equalsAtomsUncached(Atom self, Atom other) {
|
||||
if (self.getConstructor() != other.getConstructor()) {
|
||||
return false;
|
||||
}
|
||||
Type customComparator = CustomComparatorNode.getUncached().execute(self);
|
||||
if (customComparator != null) {
|
||||
Function compareFunc = findCompareMethod(customComparator);
|
||||
var invokeFuncNode = invokeCompareNode(compareFunc);
|
||||
return equalsAtomsWithCustomComparator(
|
||||
self,
|
||||
other,
|
||||
self.getConstructor(),
|
||||
CustomComparatorNode.getUncached(),
|
||||
customComparator,
|
||||
compareFunc,
|
||||
invokeFuncNode);
|
||||
}
|
||||
Object[] selfFields = StructsLibrary.getUncached().getFields(self);
|
||||
Object[] otherFields = StructsLibrary.getUncached().getFields(other);
|
||||
if (selfFields.length != otherFields.length) {
|
||||
|
@ -227,9 +227,9 @@ public abstract class EqualsNode extends Node {
|
||||
boolean equalsAtoms(
|
||||
Atom self,
|
||||
Atom other,
|
||||
@Cached EqualsAtomNode equalsNode,
|
||||
@Cached EqualsAtomNode equalsAtomNode,
|
||||
@Cached IsSameObjectNode isSameObjectNode) {
|
||||
return isSameObjectNode.execute(self, other) || equalsNode.execute(self, other);
|
||||
return isSameObjectNode.execute(self, other) || equalsAtomNode.execute(self, other);
|
||||
}
|
||||
|
||||
@Specialization(guards = "isNotPrimitive(self, other, interop, warnings)")
|
||||
|
@ -27,23 +27,31 @@ import java.util.Arrays;
|
||||
|
||||
import org.enso.interpreter.dsl.AcceptsError;
|
||||
import org.enso.interpreter.dsl.BuiltinMethod;
|
||||
import org.enso.interpreter.node.callable.InvokeCallableNode.ArgumentsExecutionMode;
|
||||
import org.enso.interpreter.node.callable.InvokeCallableNode.DefaultsExecutionMode;
|
||||
import org.enso.interpreter.node.callable.dispatch.InvokeFunctionNode;
|
||||
import org.enso.interpreter.node.expression.builtin.number.utils.BigIntegerOps;
|
||||
import org.enso.interpreter.node.expression.builtin.ordering.CustomComparatorNode;
|
||||
import org.enso.interpreter.node.expression.builtin.ordering.CustomComparatorNodeGen;
|
||||
import org.enso.interpreter.node.expression.builtin.ordering.HashCallbackNode;
|
||||
import org.enso.interpreter.runtime.EnsoContext;
|
||||
import org.enso.interpreter.runtime.Module;
|
||||
import org.enso.interpreter.runtime.callable.UnresolvedConversion;
|
||||
import org.enso.interpreter.runtime.callable.UnresolvedSymbol;
|
||||
import org.enso.interpreter.runtime.callable.argument.CallArgumentInfo;
|
||||
import org.enso.interpreter.runtime.callable.atom.Atom;
|
||||
import org.enso.interpreter.runtime.callable.atom.AtomConstructor;
|
||||
import org.enso.interpreter.runtime.callable.atom.StructsLibrary;
|
||||
import org.enso.interpreter.runtime.callable.function.Function;
|
||||
import org.enso.interpreter.runtime.data.EnsoFile;
|
||||
import org.enso.interpreter.runtime.data.Type;
|
||||
import org.enso.interpreter.runtime.data.text.Text;
|
||||
import org.enso.interpreter.runtime.error.PanicException;
|
||||
import org.enso.interpreter.runtime.error.WarningsLibrary;
|
||||
import org.enso.interpreter.runtime.library.dispatch.TypesLibrary;
|
||||
import org.enso.interpreter.runtime.number.EnsoBigInteger;
|
||||
import org.enso.interpreter.runtime.scope.ModuleScope;
|
||||
import org.enso.interpreter.runtime.state.State;
|
||||
|
||||
/**
|
||||
* Implements {@code hash_code} functionality.
|
||||
@ -194,10 +202,11 @@ public abstract class HashCodeNode extends Node {
|
||||
}
|
||||
|
||||
@Specialization(guards = {
|
||||
"atomCtorCached == atom.getConstructor()"
|
||||
}, limit = "5")
|
||||
"atomCtorCached == atom.getConstructor()",
|
||||
"customComparatorNode.execute(atom) == null",
|
||||
}, limit = "10")
|
||||
@ExplodeLoop
|
||||
long hashCodeForAtom(
|
||||
long hashCodeForAtomWithDefaultComparator(
|
||||
Atom atom,
|
||||
@Cached("atom.getConstructor()") AtomConstructor atomCtorCached,
|
||||
@Cached("atomCtorCached.getFields().length") int fieldsLenCached,
|
||||
@ -233,13 +242,76 @@ public abstract class HashCodeNode extends Node {
|
||||
return atomHashCode;
|
||||
}
|
||||
|
||||
@Specialization(
|
||||
guards = {
|
||||
"atomCtorCached == atom.getConstructor()",
|
||||
"cachedComparator != null"
|
||||
}
|
||||
)
|
||||
long hashCodeForAtomWithCustomComparator(
|
||||
Atom atom,
|
||||
@Cached("atom.getConstructor()") AtomConstructor atomCtorCached,
|
||||
@Cached CustomComparatorNode customComparatorNode,
|
||||
@CachedLibrary(limit = "5") InteropLibrary interop,
|
||||
@Cached(value = "customComparatorNode.execute(atom)") Type cachedComparator,
|
||||
@Cached(value = "findHashMethod(cachedComparator)", allowUncached = true)
|
||||
Function compareMethod,
|
||||
@Cached(value = "createInvokeNode(compareMethod)") InvokeFunctionNode invokeFunctionNode
|
||||
) {
|
||||
var ctx = EnsoContext.get(this);
|
||||
var args = new Object[] { cachedComparator, atom};
|
||||
var result = invokeFunctionNode.execute(compareMethod, null, State.create(ctx), args);
|
||||
if (!interop.isNumber(result)) {
|
||||
throw new PanicException("Custom comparator must return a number", this);
|
||||
} else {
|
||||
try {
|
||||
return interop.asLong(result);
|
||||
} catch (UnsupportedMessageException e) {
|
||||
throw new IllegalStateException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@TruffleBoundary
|
||||
@Specialization(replaces = "hashCodeForAtom")
|
||||
static Function findHashMethod(Type comparator) {
|
||||
var fn = comparator.getDefinitionScope().getMethods().get(comparator).get("hash");
|
||||
if (fn == null) {
|
||||
throw new AssertionError("No hash method for type " + comparator);
|
||||
}
|
||||
return fn;
|
||||
}
|
||||
|
||||
static InvokeFunctionNode createInvokeNode(Function compareFn) {
|
||||
CallArgumentInfo[] argsInfo = new CallArgumentInfo[compareFn.getSchema().getArgumentsCount()];
|
||||
for (int i = 0; i < argsInfo.length; i++) {
|
||||
var argDef = compareFn.getSchema().getArgumentInfos()[i];
|
||||
argsInfo[i] = new CallArgumentInfo(argDef.getName());
|
||||
}
|
||||
return InvokeFunctionNode.build(
|
||||
argsInfo, DefaultsExecutionMode.EXECUTE, ArgumentsExecutionMode.EXECUTE);
|
||||
}
|
||||
|
||||
@TruffleBoundary
|
||||
@Specialization(replaces = {"hashCodeForAtomWithDefaultComparator", "hashCodeForAtomWithCustomComparator"})
|
||||
long hashCodeForAtomUncached(Atom atom) {
|
||||
if (atom.getHashCode() != null) {
|
||||
return atom.getHashCode();
|
||||
}
|
||||
|
||||
Type customComparator = CustomComparatorNode.getUncached().execute(atom);
|
||||
if (customComparator != null) {
|
||||
Function compareMethod = findHashMethod(customComparator);
|
||||
return hashCodeForAtomWithCustomComparator(
|
||||
atom,
|
||||
atom.getConstructor(),
|
||||
CustomComparatorNodeGen.getUncached(),
|
||||
InteropLibrary.getFactory().getUncached(),
|
||||
customComparator,
|
||||
compareMethod,
|
||||
createInvokeNode(compareMethod)
|
||||
);
|
||||
}
|
||||
|
||||
Object[] fields = StructsLibrary.getUncached().getFields(atom);
|
||||
int[] hashes = new int[fields.length + 1];
|
||||
for (int i = 0; i < fields.length; i++) {
|
||||
|
@ -165,14 +165,15 @@ public final class EnsoHashMap implements TruffleObject {
|
||||
|
||||
@ExportMessage
|
||||
@TruffleBoundary
|
||||
Object toDisplayString(boolean allowSideEffects) {
|
||||
Object toDisplayString(
|
||||
boolean allowSideEffects, @CachedLibrary(limit = "5") InteropLibrary interop) {
|
||||
var sb = new StringBuilder();
|
||||
sb.append("{");
|
||||
boolean empty = true;
|
||||
for (StorageEntry entry : mapBuilder.getStorage().getValues()) {
|
||||
if (isEntryInThisMap(entry)) {
|
||||
empty = false;
|
||||
sb.append(entry.key()).append("=").append(entry.value()).append(", ");
|
||||
sb.append(entryToString(entry, interop)).append(", ");
|
||||
}
|
||||
}
|
||||
if (!empty) {
|
||||
@ -185,7 +186,26 @@ public final class EnsoHashMap implements TruffleObject {
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return (String) toDisplayString(true);
|
||||
// We are not using uncached InteropLibrary in this method, as it may substantially
|
||||
// slow down Java debugger.
|
||||
return (String) toDisplayString(true, null);
|
||||
}
|
||||
|
||||
private static String entryToString(StorageEntry entry, InteropLibrary interop) {
|
||||
String keyStr;
|
||||
String valStr;
|
||||
if (interop != null) {
|
||||
try {
|
||||
keyStr = interop.asString(interop.toDisplayString(entry.key()));
|
||||
valStr = interop.asString(interop.toDisplayString(entry.value()));
|
||||
} catch (UnsupportedMessageException e) {
|
||||
throw new IllegalStateException("Unreachable", e);
|
||||
}
|
||||
} else {
|
||||
keyStr = entry.key().toString();
|
||||
valStr = entry.value().toString();
|
||||
}
|
||||
return keyStr + "=" + valStr;
|
||||
}
|
||||
|
||||
private boolean isEntryInThisMap(StorageEntry entry) {
|
||||
|
@ -18,6 +18,20 @@ type My_Nan_Comparator
|
||||
|
||||
Comparable.from (_:My_Nan) = My_Nan_Comparator
|
||||
|
||||
type My_Key
|
||||
Value hash_code:Integer value:Text idx:Integer
|
||||
|
||||
type My_Key_Comparator
|
||||
# Comparison ignores idx field
|
||||
compare x y =
|
||||
if x.hash_code != y.hash_code then Nothing else
|
||||
if x.value == y.value then Ordering.Equal else Nothing
|
||||
|
||||
hash x = x.hash_code
|
||||
|
||||
Comparable.from (_:My_Key) = My_Key_Comparator
|
||||
|
||||
|
||||
foreign js js_str str = """
|
||||
return new String(str)
|
||||
|
||||
@ -193,6 +207,20 @@ spec =
|
||||
m3.size . should_equal 2
|
||||
m3.get k . should_equal 30
|
||||
|
||||
Test.specify "should support atom with custom comparators with complicated hash method" <|
|
||||
keys = 0.up_to 500 . map ix->
|
||||
value = ["A", "B", "C", "D", "E"].at (ix % 5)
|
||||
hash_code = Comparable.from value . hash value
|
||||
My_Key.Value hash_code value ix
|
||||
distinct_keys = keys.fold Map.empty acc_map->
|
||||
item->
|
||||
acc_map.insert item True
|
||||
distinct_keys.size . should_equal 5
|
||||
distinct_key_values = keys.map (_.value) . fold Map.empty acc_map->
|
||||
item->
|
||||
acc_map.insert item True
|
||||
distinct_key_values.size . should_equal 5
|
||||
|
||||
Test.specify "should handle keys with standard equality semantics" <|
|
||||
map = Map.singleton 2 "Hello"
|
||||
(map.get 2).should_equal "Hello"
|
||||
|
Loading…
Reference in New Issue
Block a user