HashMapNode supports atoms with custom comparators (#7165)

Add proper handling for atoms with custom comparators into the hashing machinery.
This commit is contained in:
Pavel Marek 2023-06-30 10:57:36 +02:00 committed by GitHub
parent cb9d4c4607
commit ebee8700ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 146 additions and 12 deletions

View File

@ -46,7 +46,7 @@ public abstract class EqualsAtomNode extends Node {
},
limit = "10")
@ExplodeLoop
boolean equalsAtoms(
boolean equalsAtomsWithDefaultComparator(
Atom self,
Atom other,
@Cached("self.getConstructor()") AtomConstructor selfCtorCached,
@ -81,7 +81,7 @@ public abstract class EqualsAtomNode extends Node {
"cachedComparator != null",
},
limit = "10")
boolean equalsAtoms(
boolean equalsAtomsWithCustomComparator(
Atom self,
Atom other,
@Cached("self.getConstructor()") AtomConstructor selfCtorCached,
@ -101,11 +101,25 @@ public abstract class EqualsAtomNode extends Node {
}
@CompilerDirectives.TruffleBoundary
@Specialization(replaces = "equalsAtoms")
@Specialization(
replaces = {"equalsAtomsWithDefaultComparator", "equalsAtomsWithCustomComparator"})
boolean equalsAtomsUncached(Atom self, Atom other) {
if (self.getConstructor() != other.getConstructor()) {
return false;
}
Type customComparator = CustomComparatorNode.getUncached().execute(self);
if (customComparator != null) {
Function compareFunc = findCompareMethod(customComparator);
var invokeFuncNode = invokeCompareNode(compareFunc);
return equalsAtomsWithCustomComparator(
self,
other,
self.getConstructor(),
CustomComparatorNode.getUncached(),
customComparator,
compareFunc,
invokeFuncNode);
}
Object[] selfFields = StructsLibrary.getUncached().getFields(self);
Object[] otherFields = StructsLibrary.getUncached().getFields(other);
if (selfFields.length != otherFields.length) {

View File

@ -227,9 +227,9 @@ public abstract class EqualsNode extends Node {
boolean equalsAtoms(
Atom self,
Atom other,
@Cached EqualsAtomNode equalsNode,
@Cached EqualsAtomNode equalsAtomNode,
@Cached IsSameObjectNode isSameObjectNode) {
return isSameObjectNode.execute(self, other) || equalsNode.execute(self, other);
return isSameObjectNode.execute(self, other) || equalsAtomNode.execute(self, other);
}
@Specialization(guards = "isNotPrimitive(self, other, interop, warnings)")

View File

@ -27,23 +27,31 @@ import java.util.Arrays;
import org.enso.interpreter.dsl.AcceptsError;
import org.enso.interpreter.dsl.BuiltinMethod;
import org.enso.interpreter.node.callable.InvokeCallableNode.ArgumentsExecutionMode;
import org.enso.interpreter.node.callable.InvokeCallableNode.DefaultsExecutionMode;
import org.enso.interpreter.node.callable.dispatch.InvokeFunctionNode;
import org.enso.interpreter.node.expression.builtin.number.utils.BigIntegerOps;
import org.enso.interpreter.node.expression.builtin.ordering.CustomComparatorNode;
import org.enso.interpreter.node.expression.builtin.ordering.CustomComparatorNodeGen;
import org.enso.interpreter.node.expression.builtin.ordering.HashCallbackNode;
import org.enso.interpreter.runtime.EnsoContext;
import org.enso.interpreter.runtime.Module;
import org.enso.interpreter.runtime.callable.UnresolvedConversion;
import org.enso.interpreter.runtime.callable.UnresolvedSymbol;
import org.enso.interpreter.runtime.callable.argument.CallArgumentInfo;
import org.enso.interpreter.runtime.callable.atom.Atom;
import org.enso.interpreter.runtime.callable.atom.AtomConstructor;
import org.enso.interpreter.runtime.callable.atom.StructsLibrary;
import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.runtime.data.EnsoFile;
import org.enso.interpreter.runtime.data.Type;
import org.enso.interpreter.runtime.data.text.Text;
import org.enso.interpreter.runtime.error.PanicException;
import org.enso.interpreter.runtime.error.WarningsLibrary;
import org.enso.interpreter.runtime.library.dispatch.TypesLibrary;
import org.enso.interpreter.runtime.number.EnsoBigInteger;
import org.enso.interpreter.runtime.scope.ModuleScope;
import org.enso.interpreter.runtime.state.State;
/**
* Implements {@code hash_code} functionality.
@ -194,10 +202,11 @@ public abstract class HashCodeNode extends Node {
}
@Specialization(guards = {
"atomCtorCached == atom.getConstructor()"
}, limit = "5")
"atomCtorCached == atom.getConstructor()",
"customComparatorNode.execute(atom) == null",
}, limit = "10")
@ExplodeLoop
long hashCodeForAtom(
long hashCodeForAtomWithDefaultComparator(
Atom atom,
@Cached("atom.getConstructor()") AtomConstructor atomCtorCached,
@Cached("atomCtorCached.getFields().length") int fieldsLenCached,
@ -233,13 +242,76 @@ public abstract class HashCodeNode extends Node {
return atomHashCode;
}
@Specialization(
guards = {
"atomCtorCached == atom.getConstructor()",
"cachedComparator != null"
}
)
long hashCodeForAtomWithCustomComparator(
Atom atom,
@Cached("atom.getConstructor()") AtomConstructor atomCtorCached,
@Cached CustomComparatorNode customComparatorNode,
@CachedLibrary(limit = "5") InteropLibrary interop,
@Cached(value = "customComparatorNode.execute(atom)") Type cachedComparator,
@Cached(value = "findHashMethod(cachedComparator)", allowUncached = true)
Function compareMethod,
@Cached(value = "createInvokeNode(compareMethod)") InvokeFunctionNode invokeFunctionNode
) {
var ctx = EnsoContext.get(this);
var args = new Object[] { cachedComparator, atom};
var result = invokeFunctionNode.execute(compareMethod, null, State.create(ctx), args);
if (!interop.isNumber(result)) {
throw new PanicException("Custom comparator must return a number", this);
} else {
try {
return interop.asLong(result);
} catch (UnsupportedMessageException e) {
throw new IllegalStateException(e);
}
}
}
@TruffleBoundary
@Specialization(replaces = "hashCodeForAtom")
static Function findHashMethod(Type comparator) {
var fn = comparator.getDefinitionScope().getMethods().get(comparator).get("hash");
if (fn == null) {
throw new AssertionError("No hash method for type " + comparator);
}
return fn;
}
static InvokeFunctionNode createInvokeNode(Function compareFn) {
CallArgumentInfo[] argsInfo = new CallArgumentInfo[compareFn.getSchema().getArgumentsCount()];
for (int i = 0; i < argsInfo.length; i++) {
var argDef = compareFn.getSchema().getArgumentInfos()[i];
argsInfo[i] = new CallArgumentInfo(argDef.getName());
}
return InvokeFunctionNode.build(
argsInfo, DefaultsExecutionMode.EXECUTE, ArgumentsExecutionMode.EXECUTE);
}
@TruffleBoundary
@Specialization(replaces = {"hashCodeForAtomWithDefaultComparator", "hashCodeForAtomWithCustomComparator"})
long hashCodeForAtomUncached(Atom atom) {
if (atom.getHashCode() != null) {
return atom.getHashCode();
}
Type customComparator = CustomComparatorNode.getUncached().execute(atom);
if (customComparator != null) {
Function compareMethod = findHashMethod(customComparator);
return hashCodeForAtomWithCustomComparator(
atom,
atom.getConstructor(),
CustomComparatorNodeGen.getUncached(),
InteropLibrary.getFactory().getUncached(),
customComparator,
compareMethod,
createInvokeNode(compareMethod)
);
}
Object[] fields = StructsLibrary.getUncached().getFields(atom);
int[] hashes = new int[fields.length + 1];
for (int i = 0; i < fields.length; i++) {

View File

@ -165,14 +165,15 @@ public final class EnsoHashMap implements TruffleObject {
@ExportMessage
@TruffleBoundary
Object toDisplayString(boolean allowSideEffects) {
Object toDisplayString(
boolean allowSideEffects, @CachedLibrary(limit = "5") InteropLibrary interop) {
var sb = new StringBuilder();
sb.append("{");
boolean empty = true;
for (StorageEntry entry : mapBuilder.getStorage().getValues()) {
if (isEntryInThisMap(entry)) {
empty = false;
sb.append(entry.key()).append("=").append(entry.value()).append(", ");
sb.append(entryToString(entry, interop)).append(", ");
}
}
if (!empty) {
@ -185,7 +186,26 @@ public final class EnsoHashMap implements TruffleObject {
@Override
public String toString() {
return (String) toDisplayString(true);
// We are not using uncached InteropLibrary in this method, as it may substantially
// slow down Java debugger.
return (String) toDisplayString(true, null);
}
private static String entryToString(StorageEntry entry, InteropLibrary interop) {
String keyStr;
String valStr;
if (interop != null) {
try {
keyStr = interop.asString(interop.toDisplayString(entry.key()));
valStr = interop.asString(interop.toDisplayString(entry.value()));
} catch (UnsupportedMessageException e) {
throw new IllegalStateException("Unreachable", e);
}
} else {
keyStr = entry.key().toString();
valStr = entry.value().toString();
}
return keyStr + "=" + valStr;
}
private boolean isEntryInThisMap(StorageEntry entry) {

View File

@ -18,6 +18,20 @@ type My_Nan_Comparator
Comparable.from (_:My_Nan) = My_Nan_Comparator
type My_Key
Value hash_code:Integer value:Text idx:Integer
type My_Key_Comparator
# Comparison ignores idx field
compare x y =
if x.hash_code != y.hash_code then Nothing else
if x.value == y.value then Ordering.Equal else Nothing
hash x = x.hash_code
Comparable.from (_:My_Key) = My_Key_Comparator
foreign js js_str str = """
return new String(str)
@ -193,6 +207,20 @@ spec =
m3.size . should_equal 2
m3.get k . should_equal 30
Test.specify "should support atom with custom comparators with complicated hash method" <|
keys = 0.up_to 500 . map ix->
value = ["A", "B", "C", "D", "E"].at (ix % 5)
hash_code = Comparable.from value . hash value
My_Key.Value hash_code value ix
distinct_keys = keys.fold Map.empty acc_map->
item->
acc_map.insert item True
distinct_keys.size . should_equal 5
distinct_key_values = keys.map (_.value) . fold Map.empty acc_map->
item->
acc_map.insert item True
distinct_key_values.size . should_equal 5
Test.specify "should handle keys with standard equality semantics" <|
map = Map.singleton 2 "Hello"
(map.get 2).should_equal "Hello"