Performance improvements for Comparators (#5687)

Critical performance improvements after #4067

# Important Notes
- Replace if-then-else expressions in `Any.==` with case expressions.
- Fix caching in `EqualsNode`.
- This includes fixing specializations, along with fallback guard.
This commit is contained in:
Pavel Marek 2023-02-21 01:56:11 +01:00 committed by GitHub
parent 6dddc530b6
commit 58c7ca5401
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 469 additions and 169 deletions

View File

@ -110,9 +110,13 @@ type Any
# host or polyglot values, so we just compare them with the default comparator.
eq_self = Panic.catch No_Such_Conversion (Comparable.from self) _-> Default_Unordered_Comparator
eq_that = Panic.catch No_Such_Conversion (Comparable.from that) _-> Default_Unordered_Comparator
if Meta.is_same_object eq_self Incomparable then False else
case Meta.is_same_object eq_self Incomparable of
True -> False
False ->
similar_type = Meta.is_same_object eq_self eq_that
if similar_type.not then False else
case similar_type of
False -> False
True ->
case eq_self.is_ordered of
True ->
# Comparable.equals_builtin is a hack how to directly access EqualsNode from the

View File

@ -1,5 +1,6 @@
import project.Data.Ordering.Ordering
import project.Data.Ordering.Comparable
import project.Data.Ordering.Incomparable
import project.Data.Ordering.Default_Ordered_Comparator
import project.Data.Text.Text
import project.Data.Locale.Locale
@ -940,7 +941,10 @@ type Integer
parse_builtin text radix = @Builtin_Method "Integer.parse"
Comparable.from (_:Number) = Default_Ordered_Comparator
Comparable.from (that:Number) =
case that.is_nan of
True -> Incomparable
False -> Default_Ordered_Comparator
## UNSTABLE

View File

@ -9,6 +9,7 @@ import project.Error.Unimplemented.Unimplemented
import project.Nothing
import project.Meta
import project.Meta.Atom
import project.Panic.Panic
from project.Data.Boolean import all
## Provides custom ordering, equality check and hash code for types that need it.
@ -165,9 +166,15 @@ type Default_Ordered_Comparator
## Handles only primitive types, not atoms or vectors.
compare : Any -> Any -> Ordering
compare x y =
if Comparable.less_than_builtin x y then Ordering.Less else
if Comparable.equals_builtin x y then Ordering.Equal else
if Comparable.less_than_builtin y x then Ordering.Greater
case Comparable.less_than_builtin x y of
True -> Ordering.Less
False ->
case Comparable.equals_builtin x y of
True -> Ordering.Equal
False ->
case Comparable.less_than_builtin y x of
True -> Ordering.Greater
False -> Panic.throw "Unreachable"
hash : Number -> Integer
hash x = Comparable.hash_builtin x

View File

@ -47,7 +47,7 @@ get key = @Builtin_Method "State.get"
- key: The key with which to associate the new state.
- new_state: The new state to store.
Returns an uninitialized state error if the user tries to read from an
Returns an uninitialized state error if the user tries to put into an
uninitialized slot.
> Example

View File

@ -1,9 +1,9 @@
package org.enso.interpreter.node.expression.builtin.meta;
import com.ibm.icu.text.Normalizer;
import com.oracle.truffle.api.CompilerAsserts;
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Fallback;
import com.oracle.truffle.api.dsl.GenerateUncached;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.interop.ArityException;
@ -15,18 +15,19 @@ import com.oracle.truffle.api.interop.UnknownKeyException;
import com.oracle.truffle.api.interop.UnsupportedMessageException;
import com.oracle.truffle.api.interop.UnsupportedTypeException;
import com.oracle.truffle.api.library.CachedLibrary;
import com.oracle.truffle.api.nodes.ExplodeLoop;
import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.api.profiles.ConditionProfile;
import com.oracle.truffle.api.profiles.LoopConditionProfile;
import java.math.BigInteger;
import java.time.LocalDateTime;
import java.time.ZonedDateTime;
import java.util.Arrays;
import java.util.Map;
import org.enso.interpreter.dsl.AcceptsError;
import org.enso.interpreter.dsl.BuiltinMethod;
import org.enso.interpreter.node.callable.InvokeCallableNode.ArgumentsExecutionMode;
import org.enso.interpreter.node.callable.InvokeCallableNode.DefaultsExecutionMode;
import org.enso.interpreter.node.callable.dispatch.InvokeFunctionNode;
import org.enso.interpreter.node.expression.builtin.number.utils.BigIntegerOps;
import org.enso.interpreter.node.expression.builtin.ordering.HasCustomComparatorNode;
import org.enso.interpreter.runtime.EnsoContext;
import org.enso.interpreter.runtime.Module;
@ -78,38 +79,43 @@ public abstract class EqualsNode extends Node {
@Specialization
boolean equalsBoolean(boolean self, boolean other) {
boolean equalsBoolBool(boolean self, boolean other) {
return self == other;
}
@Specialization
boolean equalsBytes(byte self, byte other) {
boolean equalsBoolDouble(boolean self, double other) {
return false;
}
@Specialization
boolean equalsBoolLong(boolean self, long other) {
return false;
}
@Specialization
boolean equalsBoolBigInt(boolean self, EnsoBigInteger other) {
return false;
}
@Specialization
boolean equalsBoolText(boolean self, Text other) {
return false;
}
@Specialization
boolean equalsByteByte(byte self, byte other) {
return self == other;
}
@Specialization
boolean equalsLong(long self, long other) {
boolean equalsLongLong(long self, long other) {
return self == other;
}
@Specialization
boolean equalsDouble(double self, double other) {
return self == other;
}
@Specialization
boolean equalsLongDouble(long self, double other) {
return (double) self == other;
}
@Specialization
boolean equalsDoubleLong(double self, long other) {
return self == (double) other;
}
@Specialization
boolean equalsIntLong(int self, long other) {
return (long) self == other;
boolean equalsLongBool(long self, boolean other) {
return false;
}
@Specialization
@ -118,10 +124,34 @@ public abstract class EqualsNode extends Node {
}
@Specialization
boolean equalsIntDouble(int self, double other) {
boolean equalsLongDouble(long self, double other) {
return (double) self == other;
}
@Specialization
boolean equalsLongText(long self, Text other) {
return false;
}
@Specialization
boolean equalsDoubleDouble(double self, double other) {
if (Double.isNaN(self) || Double.isNaN(other)) {
return false;
} else {
return self == other;
}
}
@Specialization
boolean equalsDoubleLong(double self, long other) {
return self == (double) other;
}
@Specialization
boolean equalsDoubleBool(double self, boolean other) {
return false;
}
@Specialization
boolean equalsDoubleInt(double self, int other) {
return self == (double) other;
@ -129,7 +159,33 @@ public abstract class EqualsNode extends Node {
@Specialization
@TruffleBoundary
boolean equalsBigInt(EnsoBigInteger self, EnsoBigInteger otherBigInt) {
boolean equalsDoubleBigInt(double self, EnsoBigInteger other) {
return self == other.doubleValue();
}
@Specialization
boolean equalsDoubleText(double self, Text other) {
return false;
}
@Specialization
boolean equalsIntInt(int self, int other) {
return self == other;
}
@Specialization
boolean equalsIntLong(int self, long other) {
return (long) self == other;
}
@Specialization
boolean equalsIntDouble(int self, double other) {
return (double) self == other;
}
@Specialization
@TruffleBoundary
boolean equalsBigIntBigInt(EnsoBigInteger self, EnsoBigInteger otherBigInt) {
return self.equals(otherBigInt);
}
@ -141,8 +197,63 @@ public abstract class EqualsNode extends Node {
@Specialization
@TruffleBoundary
boolean equalsDoubleBigInt(double self, EnsoBigInteger other) {
return self == other.doubleValue();
boolean equalsBigIntLong(EnsoBigInteger self, long other) {
if (BigIntegerOps.fitsInLong(self.getValue())) {
return self.getValue().compareTo(BigInteger.valueOf(other)) == 0;
} else {
return false;
}
}
@Specialization
boolean equalsBigIntBool(EnsoBigInteger self, boolean other) {
return false;
}
@Specialization
boolean equalsBigIntText(EnsoBigInteger self, Text other) {
return false;
}
@Specialization
@TruffleBoundary
boolean equalsLongBigInt(long self, EnsoBigInteger other) {
if (BigIntegerOps.fitsInLong(other.getValue())) {
return BigInteger.valueOf(self).compareTo(other.getValue()) == 0;
} else {
return false;
}
}
@Specialization(limit = "3")
boolean equalsTextText(Text selfText, Text otherText,
@CachedLibrary("selfText") InteropLibrary selfInterop,
@CachedLibrary("otherText") InteropLibrary otherInterop) {
if (selfText.is_normalized() && otherText.is_normalized()) {
return selfText.toString().compareTo(otherText.toString()) == 0;
} else {
return equalsStrings(selfText, otherText, selfInterop, otherInterop);
}
}
@Specialization
boolean equalsTextBool(Text self, boolean other) {
return false;
}
@Specialization
boolean equalsTextLong(Text selfText, long otherLong) {
return false;
}
@Specialization
boolean equalsTextDouble(Text selfText, double otherDouble) {
return false;
}
@Specialization
boolean equalsTextBigInt(Text self, EnsoBigInteger other) {
return false;
}
/**
@ -225,17 +336,6 @@ public abstract class EqualsNode extends Node {
}
}
@Specialization(limit = "3")
boolean equalsTexts(Text selfText, Text otherText,
@CachedLibrary("selfText") InteropLibrary selfInterop,
@CachedLibrary("otherText") InteropLibrary otherInterop) {
if (selfText.is_normalized() && otherText.is_normalized()) {
return selfText.toString().compareTo(otherText.toString()) == 0;
} else {
return equalsStrings(selfText, otherText, selfInterop, otherInterop);
}
}
/** Interop libraries **/
@Specialization(guards = {
@ -249,7 +349,6 @@ public abstract class EqualsNode extends Node {
return selfInterop.isNull(selfNull) && otherInterop.isNull(otherNull);
}
@Specialization(guards = {
"selfInterop.isBoolean(selfBoolean)",
"otherInterop.isBoolean(otherBoolean)"
@ -268,12 +367,8 @@ public abstract class EqualsNode extends Node {
}
@Specialization(guards = {
"!selfInterop.isDate(selfTimeZone)",
"!selfInterop.isTime(selfTimeZone)",
"selfInterop.isTimeZone(selfTimeZone)",
"!otherInterop.isDate(otherTimeZone)",
"!otherInterop.isTime(otherTimeZone)",
"otherInterop.isTimeZone(otherTimeZone)"
"isTimeZone(selfTimeZone, selfInterop)",
"isTimeZone(otherTimeZone, otherInterop)",
}, limit = "3")
boolean equalsTimeZones(Object selfTimeZone, Object otherTimeZone,
@CachedLibrary("selfTimeZone") InteropLibrary selfInterop,
@ -289,12 +384,8 @@ public abstract class EqualsNode extends Node {
@TruffleBoundary
@Specialization(guards = {
"selfInterop.isDate(selfZonedDateTime)",
"selfInterop.isTime(selfZonedDateTime)",
"selfInterop.isTimeZone(selfZonedDateTime)",
"otherInterop.isDate(otherZonedDateTime)",
"otherInterop.isTime(otherZonedDateTime)",
"otherInterop.isTimeZone(otherZonedDateTime)"
"isZonedDateTime(selfZonedDateTime, selfInterop)",
"isZonedDateTime(otherZonedDateTime, otherInterop)",
}, limit = "3")
boolean equalsZonedDateTimes(Object selfZonedDateTime, Object otherZonedDateTime,
@CachedLibrary("selfZonedDateTime") InteropLibrary selfInterop,
@ -318,12 +409,8 @@ public abstract class EqualsNode extends Node {
}
@Specialization(guards = {
"selfInterop.isDate(selfDateTime)",
"selfInterop.isTime(selfDateTime)",
"!selfInterop.isTimeZone(selfDateTime)",
"otherInterop.isDate(otherDateTime)",
"otherInterop.isTime(otherDateTime)",
"!otherInterop.isTimeZone(otherDateTime)"
"isDateTime(selfDateTime, selfInterop)",
"isDateTime(otherDateTime, otherInterop)",
}, limit = "3")
boolean equalsDateTimes(Object selfDateTime, Object otherDateTime,
@CachedLibrary("selfDateTime") InteropLibrary selfInterop,
@ -344,12 +431,8 @@ public abstract class EqualsNode extends Node {
}
@Specialization(guards = {
"selfInterop.isDate(selfDate)",
"!selfInterop.isTime(selfDate)",
"!selfInterop.isTimeZone(selfDate)",
"otherInterop.isDate(otherDate)",
"!otherInterop.isTime(otherDate)",
"!otherInterop.isTimeZone(otherDate)"
"isDate(selfDate, selfInterop)",
"isDate(otherDate, otherInterop)",
}, limit = "3")
boolean equalsDates(Object selfDate, Object otherDate,
@CachedLibrary("selfDate") InteropLibrary selfInterop,
@ -364,12 +447,8 @@ public abstract class EqualsNode extends Node {
}
@Specialization(guards = {
"!selfInterop.isDate(selfTime)",
"selfInterop.isTime(selfTime)",
"!selfInterop.isTimeZone(selfTime)",
"!otherInterop.isDate(otherTime)",
"otherInterop.isTime(otherTime)",
"!otherInterop.isTimeZone(otherTime)"
"isTime(selfTime, selfInterop)",
"isTime(otherTime, otherInterop)",
}, limit = "3")
boolean equalsTimes(Object selfTime, Object otherTime,
@CachedLibrary("selfTime") InteropLibrary selfInterop,
@ -505,21 +584,8 @@ public abstract class EqualsNode extends Node {
}
@Specialization(guards = {
"!isAtom(selfObject)",
"!isAtom(otherObject)",
"!isHostObject(selfObject)",
"!isHostObject(otherObject)",
"interop.hasMembers(selfObject)",
"interop.hasMembers(otherObject)",
"!interop.isDate(selfObject)",
"!interop.isDate(otherObject)",
"!interop.isTime(selfObject)",
"!interop.isTime(otherObject)",
// Objects with types are handled in `equalsTypes` specialization, so we have to
// negate the guards of that specialization here - to make the specializations
// disjunctive.
"!typesLib.hasType(selfObject)",
"!typesLib.hasType(otherObject)",
"isObjectWithMembers(selfObject, interop)",
"isObjectWithMembers(otherObject, interop)",
})
boolean equalsInteropObjectWithMembers(Object selfObject, Object otherObject,
@CachedLibrary(limit = "10") InteropLibrary interop,
@ -584,32 +650,32 @@ public abstract class EqualsNode extends Node {
return nodes;
}
@Specialization
@Specialization(guards = {
"selfCtorCached == self.getConstructor()"
}, limit = "10")
@ExplodeLoop
boolean equalsAtoms(
Atom self,
Atom other,
@Cached LoopConditionProfile loopProfile,
@Cached(value = "createEqualsNodes(equalsNodeCountForFields)", allowUncached = true) EqualsNode[] fieldEqualsNodes,
@Cached ConditionProfile enoughEqualNodesForFieldsProfile,
@Cached("self.getConstructor()") AtomConstructor selfCtorCached,
@Cached(value = "selfCtorCached.getFields().length", allowUncached = true) int fieldsLenCached,
@Cached(value = "createEqualsNodes(fieldsLenCached)", allowUncached = true) EqualsNode[] fieldEqualsNodes,
@Cached ConditionProfile constructorsNotEqualProfile,
@CachedLibrary(limit = "3") StructsLibrary selfStructs,
@CachedLibrary(limit = "3") StructsLibrary otherStructs,
@Cached HasCustomComparatorNode hasCustomComparatorNode,
@Cached InvokeAnyEqualsNode invokeAnyEqualsNode
@Cached InvokeAnyEqualsNode invokeAnyEqualsNode,
@CachedLibrary(limit = "5") StructsLibrary structsLib
) {
if (constructorsNotEqualProfile.profile(
self.getConstructor() != other.getConstructor()
)) {
return false;
}
var selfFields = selfStructs.getFields(self);
var otherFields = otherStructs.getFields(other);
assert selfFields.length == otherFields.length;
var selfFields = structsLib.getFields(self);
var otherFields = structsLib.getFields(other);
assert selfFields.length == otherFields.length : "Constructors are same, atoms should have the same number of fields";
int fieldsSize = selfFields.length;
if (enoughEqualNodesForFieldsProfile.profile(fieldsSize <= equalsNodeCountForFields)) {
loopProfile.profileCounted(fieldsSize);
for (int i = 0; loopProfile.inject(i < fieldsSize); i++) {
CompilerAsserts.partialEvaluationConstant(fieldsLenCached);
for (int i = 0; i < fieldsLenCached; i++) {
boolean fieldsAreEqual;
// We don't check whether `other` has the same type of comparator, that is checked in
// `Any.==` that we invoke here anyway.
@ -620,21 +686,29 @@ public abstract class EqualsNode extends Node {
// custom comparators. EqualsNode cannot deal with custom comparators.
fieldsAreEqual = invokeAnyEqualsNode.execute(selfAtomField, otherAtomField);
} else {
fieldsAreEqual = fieldEqualsNodes[i].execute(selfFields[i], otherFields[i]);
fieldsAreEqual = fieldEqualsNodes[i].execute(
selfFields[i],
otherFields[i]
);
}
if (!fieldsAreEqual) {
return false;
}
}
} else {
return equalsAtomsFieldsUncached(selfFields, otherFields);
}
return true;
}
@TruffleBoundary
private static boolean equalsAtomsFieldsUncached(Object[] selfFields, Object[] otherFields) {
assert selfFields.length == otherFields.length;
@Specialization(replaces = "equalsAtoms")
boolean equalsAtomsUncached(Atom self, Atom other) {
if (!equalsAtomConstructors(self.getConstructor(), other.getConstructor())) {
return false;
}
Object[] selfFields = StructsLibrary.getUncached().getFields(self);
Object[] otherFields = StructsLibrary.getUncached().getFields(other);
if (selfFields.length != otherFields.length) {
return false;
}
for (int i = 0; i < selfFields.length; i++) {
boolean areFieldsSame;
if (selfFields[i] instanceof Atom selfFieldAtom
@ -683,17 +757,136 @@ public abstract class EqualsNode extends Node {
return equalsNode.execute(selfFuncStrRepr, otherFuncStrRepr);
}
@Fallback
@Specialization(guards = "fallbackGuard(left, right, interop)")
@TruffleBoundary
boolean equalsGeneric(Object left, Object right,
@CachedLibrary(limit = "5") InteropLibrary interop,
@CachedLibrary(limit = "5") TypesLibrary typesLib) {
@CachedLibrary(limit = "10") InteropLibrary interop,
@CachedLibrary(limit = "10") TypesLibrary typesLib) {
return left == right
|| interop.isIdentical(left, right, interop)
|| left.equals(right)
|| (isNullOrNothing(left, typesLib, interop) && isNullOrNothing(right, typesLib, interop));
}
// We have to manually specify negation of guards of other specializations, because
// we cannot use @Fallback here. Note that this guard is not precisely the negation of
// all the other guards on purpose.
boolean fallbackGuard(Object left, Object right, InteropLibrary interop) {
if (isPrimitive(left) && isPrimitive(right)) {
return false;
}
if (isHostObject(left) && isHostObject(right)) {
return false;
}
if (isHostFunction(left) && isHostFunction(right)) {
return false;
}
if (left instanceof Atom && right instanceof Atom) {
return false;
}
if (interop.isNull(left) && interop.isNull(right)) {
return false;
}
if (interop.isString(left) && interop.isString(right)) {
return false;
}
if (interop.hasArrayElements(left) && interop.hasArrayElements(right)) {
return false;
}
if (interop.hasHashEntries(left) && interop.hasHashEntries(right)) {
return false;
}
if (isObjectWithMembers(left, interop) && isObjectWithMembers(right, interop)) {
return false;
}
if (isTimeZone(left, interop) && isTimeZone(right, interop)) {
return false;
}
if (isZonedDateTime(left, interop) && isZonedDateTime(right, interop)) {
return false;
}
if (isDateTime(left, interop) && isDateTime(right, interop)) {
return false;
}
if (isDate(left, interop) && isDate(right, interop)) {
return false;
}
if (isTime(left, interop) && isTime(right, interop)) {
return false;
}
if (interop.isDuration(left) && interop.isDuration(right)) {
return false;
}
// For all other cases, fall through to the generic specialization
return true;
}
/**
* Return true iff object is a primitive value used in some of the specializations
* guard. By primitive value we mean any value that can be present in Enso, so,
* for example, not Integer, as that cannot be present in Enso.
* All the primitive types should be handled in their corresponding specializations.
* See {@link org.enso.interpreter.node.expression.builtin.interop.syntax.HostValueToEnsoNode}.
*/
private static boolean isPrimitive(Object object) {
return object instanceof Boolean ||
object instanceof Long ||
object instanceof Double ||
object instanceof EnsoBigInteger ||
object instanceof Text;
}
boolean isTimeZone(Object object, InteropLibrary interop) {
return
!interop.isTime(object) &&
!interop.isDate(object) &&
interop.isTimeZone(object);
}
boolean isZonedDateTime(Object object, InteropLibrary interop) {
return
interop.isTime(object) &&
interop.isDate(object) &&
interop.isTimeZone(object);
}
boolean isDateTime(Object object, InteropLibrary interop) {
return
interop.isTime(object) &&
interop.isDate(object) &&
!interop.isTimeZone(object);
}
boolean isDate(Object object, InteropLibrary interop) {
return
!interop.isTime(object) &&
interop.isDate(object) &&
!interop.isTimeZone(object);
}
boolean isTime(Object object, InteropLibrary interop) {
return
interop.isTime(object) &&
!interop.isDate(object) &&
!interop.isTimeZone(object);
}
boolean isObjectWithMembers(Object object, InteropLibrary interop) {
if (object instanceof Atom) {
return false;
}
if (isHostObject(object)) {
return false;
}
if (interop.isDate(object)) {
return false;
}
if (interop.isTime(object)) {
return false;
}
return interop.hasMembers(object);
}
private boolean isNullOrNothing(Object object, TypesLibrary typesLib, InteropLibrary interop) {
if (typesLib.hasType(object)) {
return typesLib.getType(object) == EnsoContext.get(this).getNothing();
@ -734,11 +927,11 @@ public abstract class EqualsNode extends Node {
@Cached(value = "getAnyEqualsMethod()", allowUncached = true) Function anyEqualsFunc,
@Cached(value = "buildInvokeFuncNodeForAnyEquals()", allowUncached = true) InvokeFunctionNode invokeAnyEqualsNode,
@CachedLibrary(limit = "3") InteropLibrary interop) {
// TODO: Shouldn't Comparable type be the very first argument? (synthetic self)?
Object ret = invokeAnyEqualsNode.execute(
anyEqualsFunc,
null,
State.create(EnsoContext.get(this)),
// TODO: Shouldn't Any type be the very first argument? (synthetic self)?
new Object[]{selfAtom, thatAtom}
);
try {

View File

@ -2,6 +2,7 @@ package org.enso.interpreter.node.expression.builtin.meta;
import com.google.common.base.Objects;
import com.ibm.icu.text.Normalizer2;
import com.oracle.truffle.api.CompilerAsserts;
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.GenerateUncached;
@ -14,6 +15,7 @@ import com.oracle.truffle.api.interop.UnknownIdentifierException;
import com.oracle.truffle.api.interop.UnsupportedMessageException;
import com.oracle.truffle.api.interop.UnsupportedTypeException;
import com.oracle.truffle.api.library.CachedLibrary;
import com.oracle.truffle.api.nodes.ExplodeLoop;
import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.api.profiles.ConditionProfile;
import com.oracle.truffle.api.profiles.LoopConditionProfile;
@ -103,7 +105,10 @@ public abstract class HashCodeNode extends Node {
@Specialization
long hashCodeForDouble(double d) {
if (d % 1.0 != 0 || BigIntegerOps.fitsInLong(d)) {
if (Double.isNaN(d)) {
// NaN is Incomparable, just return a "random" constant
return 456879;
} else if (d % 1.0 != 0 || BigIntegerOps.fitsInLong(d)) {
return Double.hashCode(d);
} else {
return bigDoubleHash(d);
@ -127,6 +132,7 @@ public abstract class HashCodeNode extends Node {
@Specialization
long hashCodeForAtomConstructor(AtomConstructor atomConstructor) {
// AtomConstructors are singletons, we take system hash code explicitly.
return System.identityHashCode(atomConstructor);
}
@ -181,23 +187,23 @@ public abstract class HashCodeNode extends Node {
}
}
/** How many {@link HashCodeNode} nodes should be created for fields in atoms. */
static final int hashCodeNodeCountForFields = 10;
static HashCodeNode[] createHashCodeNodes(int size) {
HashCodeNode[] nodes = new HashCodeNode[size];
Arrays.fill(nodes, HashCodeNode.build());
return nodes;
}
@Specialization
@Specialization(guards = {
"atomCtorCached == atom.getConstructor()"
}, limit = "5")
@ExplodeLoop
long hashCodeForAtom(
Atom atom,
@Cached(value = "createHashCodeNodes(hashCodeNodeCountForFields)", allowUncached = true)
@Cached("atom.getConstructor()") AtomConstructor atomCtorCached,
@Cached("atomCtorCached.getFields().length") int fieldsLenCached,
@Cached(value = "createHashCodeNodes(fieldsLenCached)", allowUncached = true)
HashCodeNode[] fieldHashCodeNodes,
@Cached ConditionProfile isHashCodeCached,
@Cached ConditionProfile enoughHashCodeNodesForFields,
@Cached LoopConditionProfile loopProfile,
@CachedLibrary(limit = "10") StructsLibrary structs,
@Cached HasCustomComparatorNode hasCustomComparatorNode,
@Cached HashCallbackNode hashCallbackNode) {
@ -208,22 +214,18 @@ public abstract class HashCodeNode extends Node {
Object[] fields = structs.getFields(atom);
int fieldsCount = fields.length;
CompilerAsserts.partialEvaluationConstant(fieldsLenCached);
// hashes stores hash codes for all fields, and for constructor.
int[] hashes = new int[fieldsCount + 1];
if (enoughHashCodeNodesForFields.profile(fieldsCount <= hashCodeNodeCountForFields)) {
loopProfile.profileCounted(fieldsCount);
for (int i = 0; loopProfile.inject(i < fieldsCount); i++) {
for (int i = 0; i < fieldsLenCached; i++) {
if (fields[i] instanceof Atom atomField && hasCustomComparatorNode.execute(atomField)) {
hashes[i] = (int) hashCallbackNode.execute(atomField);
} else {
hashes[i] = (int) fieldHashCodeNodes[i].execute(fields[i]);
}
}
} else {
hashCodeForAtomFieldsUncached(fields, hashes);
}
int ctorHashCode = System.identityHashCode(atom.getConstructor());
int ctorHashCode = (int) hashCodeForAtomConstructor(atom.getConstructor());
hashes[hashes.length - 1] = ctorHashCode;
int atomHashCode = Arrays.hashCode(hashes);
@ -232,15 +234,29 @@ public abstract class HashCodeNode extends Node {
}
@TruffleBoundary
private void hashCodeForAtomFieldsUncached(Object[] fields, int[] fieldHashes) {
@Specialization(replaces = "hashCodeForAtom")
long hashCodeForAtomUncached(Atom atom) {
if (atom.getHashCode() != null) {
return atom.getHashCode();
}
Object[] fields = StructsLibrary.getUncached().getFields(atom);
int[] hashes = new int[fields.length + 1];
for (int i = 0; i < fields.length; i++) {
if (fields[i] instanceof Atom atomField
&& HasCustomComparatorNode.getUncached().execute(atomField)) {
fieldHashes[i] = (int) HashCallbackNode.getUncached().execute(atomField);
hashes[i] = (int) HashCallbackNode.getUncached().execute(atomField);
} else {
fieldHashes[i] = (int) HashCodeNodeGen.getUncached().execute(fields[i]);
hashes[i] = (int) HashCodeNodeGen.getUncached().execute(fields[i]);
}
}
int ctorHashCode = (int) hashCodeForAtomConstructor(atom.getConstructor());
hashes[hashes.length - 1] = ctorHashCode;
int atomHashCode = Arrays.hashCode(hashes);
atom.setHashCode(atomHashCode);
return atomHashCode;
}
@Specialization(
@ -434,7 +450,10 @@ public abstract class HashCodeNode extends Node {
* Two maps are considered equal, if they have the same entries. Note that we do not care about
* ordering.
*/
@Specialization(guards = "interop.hasHashEntries(selfMap)")
@Specialization(guards = {
"interop.hasHashEntries(selfMap)",
"!interop.hasArrayElements(selfMap)",
})
long hashCodeForMap(
Object selfMap,
@CachedLibrary(limit = "5") InteropLibrary interop,

View File

@ -1,6 +1,7 @@
package org.enso.interpreter.runtime.data;
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.interop.InteropLibrary;
import com.oracle.truffle.api.interop.InvalidArrayIndexException;
import com.oracle.truffle.api.interop.TruffleObject;
@ -9,6 +10,7 @@ import com.oracle.truffle.api.library.CachedLibrary;
import com.oracle.truffle.api.library.ExportLibrary;
import com.oracle.truffle.api.library.ExportMessage;
import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.api.profiles.BranchProfile;
import org.enso.interpreter.dsl.Builtin;
import org.enso.interpreter.runtime.EnsoContext;
import org.enso.interpreter.runtime.error.Warning;
@ -25,7 +27,7 @@ import org.enso.interpreter.runtime.error.WithWarnings;
@Builtin(pkg = "mutable", stdlibName = "Standard.Base.Data.Array.Array")
public final class Array implements TruffleObject {
private final Object[] items;
private @CompilerDirectives.CompilationFinal Boolean withWarnings;
private Boolean withWarnings;
/**
* Creates a new array
@ -75,14 +77,20 @@ public final class Array implements TruffleObject {
* @throws InvalidArrayIndexException when the index is out of bounds.
*/
@ExportMessage
public Object readArrayElement(long index, @CachedLibrary(limit = "3") WarningsLibrary warnings)
public Object readArrayElement(
long index,
@CachedLibrary(limit = "3") WarningsLibrary warnings,
@Cached BranchProfile errProfile,
@Cached BranchProfile hasWarningsProfile)
throws InvalidArrayIndexException, UnsupportedMessageException {
if (index >= items.length || index < 0) {
errProfile.enter();
throw InvalidArrayIndexException.create(index);
}
var v = items[(int) index];
if (this.hasWarnings(warnings)) {
hasWarningsProfile.enter();
Warning[] extracted = this.getWarnings(null, warnings);
if (warnings.hasWarnings(v)) {
v = warnings.removeWarnings(v);

View File

@ -11,7 +11,9 @@ import java.time.ZonedDateTime;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import org.enso.interpreter.node.expression.builtin.interop.syntax.HostValueToEnsoNode;
import org.enso.interpreter.node.expression.builtin.meta.EqualsNode;
import org.enso.interpreter.node.expression.builtin.meta.EqualsNodeGen;
import org.graalvm.polyglot.Context;
import org.graalvm.polyglot.Value;
import org.junit.AfterClass;
@ -28,19 +30,21 @@ public class EqualsTest extends TestBase {
private static Context context;
private static EqualsNode equalsNode;
private static TestRootNode testRootNode;
private static HostValueToEnsoNode hostValueToEnsoNode;
@BeforeClass
public static void initContextAndData() {
context = createDefaultContext();
unwrappedValues = fetchAllUnwrappedValues();
executeInContext(
context,
() -> {
testRootNode = new TestRootNode();
equalsNode = EqualsNode.build();
testRootNode.insertChildren(equalsNode);
hostValueToEnsoNode = HostValueToEnsoNode.build();
testRootNode.insertChildren(equalsNode, hostValueToEnsoNode);
return null;
});
unwrappedValues = fetchAllUnwrappedValues();
}
@AfterClass
@ -74,6 +78,7 @@ public class EqualsTest extends TestBase {
try {
return values.stream()
.map(value -> unwrapValue(context, value))
.map(unwrappedValue -> hostValueToEnsoNode.execute(unwrappedValue))
.collect(Collectors.toList())
.toArray(new Object[] {});
} catch (Exception e) {
@ -105,6 +110,21 @@ public class EqualsTest extends TestBase {
});
}
@Theory
public void equalsNodeCachedIsConsistentWithUncached(Object firstVal, Object secondVal) {
executeInContext(
context,
() -> {
boolean uncachedRes = EqualsNodeGen.getUncached().execute(firstVal, secondVal);
boolean cachedRes = equalsNode.execute(firstVal, secondVal);
assertEquals(
"Result from uncached EqualsNode should be the same as result from its cached variant",
uncachedRes,
cachedRes);
return null;
});
}
/** Test for some specific values, for which we know that they are equal. */
@Test
public void testDateEquality() {

View File

@ -7,8 +7,10 @@ import com.oracle.truffle.api.interop.InteropLibrary;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import org.enso.interpreter.node.expression.builtin.interop.syntax.HostValueToEnsoNode;
import org.enso.interpreter.node.expression.builtin.meta.EqualsNode;
import org.enso.interpreter.node.expression.builtin.meta.HashCodeNode;
import org.enso.interpreter.node.expression.builtin.meta.HashCodeNodeGen;
import org.graalvm.polyglot.Context;
import org.graalvm.polyglot.Value;
import org.junit.AfterClass;
@ -26,20 +28,22 @@ public class HashCodeTest extends TestBase {
private static HashCodeNode hashCodeNode;
private static EqualsNode equalsNode;
private static HostValueToEnsoNode hostValueToEnsoNode;
private static TestRootNode testRootNode;
@BeforeClass
public static void initContextAndData() {
context = createDefaultContext();
// Initialize datapoints here, to make sure that it is initialized just once.
unwrappedValues = fetchAllUnwrappedValues();
executeInContext(context, () -> {
hashCodeNode = HashCodeNode.build();
equalsNode = EqualsNode.build();
hostValueToEnsoNode = HostValueToEnsoNode.build();
testRootNode = new TestRootNode();
testRootNode.insertChildren(hashCodeNode, equalsNode);
testRootNode.insertChildren(hashCodeNode, equalsNode, hostValueToEnsoNode);
return null;
});
// Initialize datapoints here, to make sure that it is initialized just once.
unwrappedValues = fetchAllUnwrappedValues();
}
@AfterClass
@ -79,6 +83,7 @@ public class HashCodeTest extends TestBase {
return values
.stream()
.map(value -> unwrapValue(context, value))
.map(unwrappedValue -> hostValueToEnsoNode.execute(unwrappedValue))
.collect(Collectors.toList())
.toArray(new Object[]{});
} catch (Exception e) {
@ -132,4 +137,18 @@ public class HashCodeTest extends TestBase {
return null;
});
}
@Theory
public void hashCodeCachedNodeIsConsistentWithUncached(Object value) {
executeInContext(context, () -> {
long uncachedRes = HashCodeNodeGen.getUncached().execute(value);
long cachedRes = hashCodeNode.execute(value);
assertEquals(
"Result from cached HashCodeNode should be the same as from its uncached variant",
uncachedRes,
cachedRes
);
return null;
});
}
}

View File

@ -19,6 +19,7 @@ import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TimeZone;
import org.enso.polyglot.MethodNames.Module;
import org.graalvm.polyglot.Context;
import org.graalvm.polyglot.PolyglotException;
import org.graalvm.polyglot.Value;
@ -74,6 +75,32 @@ class ValuesGenerator {
return v;
}
/**
* Converts expressions into values of type described by {@code typeDefs} by concatenating
* everything into a single source.
*
* This method exists so that there are no multiple definitions of a single type.
*
* @param typeDefs Type definitions.
* @param expressions List of expressions - every expression will be converted to a {@link Value}.
* @return List of values converted from the given expressions.
*/
private List<Value> createValuesOfCustomType(String typeDefs, List<String> expressions) {
var sb = new StringBuilder();
sb.append(typeDefs);
sb.append("\n");
for (int i = 0; i < expressions.size(); i++) {
sb.append("var_").append(i).append(" = ").append(expressions.get(i)).append("\n");
}
Value module = ctx.eval("enso", sb.toString());
List<Value> values = new ArrayList<>(expressions.size());
for (int i = 0; i < expressions.size(); i++) {
Value val = module.invokeMember(Module.EVAL_EXPRESSION, "var_" + i);
values.add(val);
}
return values;
}
public Value typeAny() {
return v("typeAny", """
import Standard.Base.Any.Any
@ -521,7 +548,7 @@ class ValuesGenerator {
Nil
Value value
""";
for (var expr : List.of(
var exprs = List.of(
"Node.C2 Node.Nil (Node.Value 42)",
"Node.C2 (Node.Value 42) Node.Nil",
"Node.Nil",
@ -536,9 +563,8 @@ class ValuesGenerator {
"Node.C2 (Node.C2 (Node.C1 Node.Nil) (Node.C1 (Node.C1 Node.Nil))) (Node.C2 (Node.C3 (Node.Nil) (Node.Value 22) (Node.Nil)) (Node.C2 (Node.Value 22) (Node.Nil)))",
"Node.C2 (Node.C2 (Node.C1 Node.Nil) (Node.C1 Node.Nil)) (Node.C2 (Node.C3 (Node.Nil) (Node.Value 22) (Node.Nil)) (Node.C2 (Node.Value 22) (Node.Nil)))",
"Node.C2 (Node.C2 (Node.C1 Node.Nil) (Node.C1 Node.Nil)) (Node.C2 (Node.C3 (Node.Nil) (Node.Nil) (Node.Value 22)) (Node.C2 (Node.Value 22) (Node.Nil)))"
)) {
collect.add(v(null, nodeTypeDef, expr).type());
}
);
collect.addAll(createValuesOfCustomType(nodeTypeDef, exprs));
}
return collect;
}