Refactor OrderMask to avoid memory copying (#8863)

The goal of this PR is to refactor the design of OrderMask to avoid copying arrays or lists wherever possible.
We have removed a few legacy functions which were not being used.

On a poor man's benchmark, it seems to be quicker (13s vs 16s), and memory usage should be lower.
This commit is contained in:
James Dunkerley 2024-01-26 11:16:16 +00:00 committed by GitHub
parent 5dd2dc1c93
commit 0b6db5797c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 213 additions and 183 deletions

View File

@ -32,6 +32,7 @@ from project.Internal.Column_Format import all
from project.Internal.Java_Exports import make_date_builder_adapter, make_string_builder from project.Internal.Java_Exports import make_date_builder_adapter, make_string_builder
polyglot java import org.enso.base.Time_Utils polyglot java import org.enso.base.Time_Utils
polyglot java import org.enso.table.data.mask.OrderMask
polyglot java import org.enso.table.data.column.operation.cast.CastProblemAggregator polyglot java import org.enso.table.data.column.operation.cast.CastProblemAggregator
polyglot java import org.enso.table.data.column.storage.Storage as Java_Storage polyglot java import org.enso.table.data.column.storage.Storage as Java_Storage
polyglot java import org.enso.table.data.table.Column as Java_Column polyglot java import org.enso.table.data.table.Column as Java_Column
@ -2284,7 +2285,7 @@ type Column
example_reverse = Examples.integer_column.reverse example_reverse = Examples.integer_column.reverse
reverse : Column reverse : Column
reverse self = reverse self =
mask = OrderBuilder.buildReversedMask self.length mask = OrderMask.reverse self.length
Column.Value (self.java_column.applyMask mask) Column.Value (self.java_column.applyMask mask)
## GROUP Standard.Base.Metadata ## GROUP Standard.Base.Metadata

View File

@ -78,7 +78,6 @@ polyglot java import org.enso.table.error.NonUniqueLookupKey
polyglot java import org.enso.table.error.NullValuesInKeyColumns polyglot java import org.enso.table.error.NullValuesInKeyColumns
polyglot java import org.enso.table.error.TooManyColumnsException polyglot java import org.enso.table.error.TooManyColumnsException
polyglot java import org.enso.table.error.UnmatchedRow polyglot java import org.enso.table.error.UnmatchedRow
polyglot java import org.enso.table.operations.OrderBuilder
polyglot java import org.enso.table.parsing.problems.ParseProblemAggregator polyglot java import org.enso.table.parsing.problems.ParseProblemAggregator
## Represents a column-oriented table data structure. ## Represents a column-oriented table data structure.
@ -2423,7 +2422,7 @@ type Table
example_reverse = Examples.inventory_table.reverse example_reverse = Examples.inventory_table.reverse
reverse : Table reverse : Table
reverse self = reverse self =
mask = OrderBuilder.buildReversedMask self.row_count mask = OrderMask.reverse self.row_count
Table.Value <| self.java_table.applyMask mask Table.Value <| self.java_table.applyMask mask
## GROUP Standard.Base.Output ## GROUP Standard.Base.Output

View File

@ -12,7 +12,6 @@ polyglot java import java.lang.IllegalArgumentException
polyglot java import java.time.temporal.UnsupportedTemporalTypeException polyglot java import java.time.temporal.UnsupportedTemporalTypeException
polyglot java import org.enso.table.data.column.storage.Storage as Java_Storage polyglot java import org.enso.table.data.column.storage.Storage as Java_Storage
polyglot java import org.enso.table.data.table.Column as Java_Column polyglot java import org.enso.table.data.table.Column as Java_Column
polyglot java import org.enso.table.operations.OrderBuilder
## PRIVATE ## PRIVATE
Create a formatter for the specified `Value_Type`. Create a formatter for the specified `Value_Type`.

View File

@ -137,7 +137,7 @@ fan_out_to_rows_and_columns table input_column_id function column_names at_least
Column.from_storage column_name output_storage Column.from_storage column_name output_storage
# Build the order mask. # Build the order mask.
order_mask = OrderMask.new (order_mask_positions.to_vector) order_mask = OrderMask.fromArray (order_mask_positions.to_vector)
## Build the new table, replacing the input column with the new output ## Build the new table, replacing the input column with the new output
columns. columns.

View File

@ -4,8 +4,6 @@ import project.Data.Table.Table
import project.Data.Type.Value_Type.Value_Type import project.Data.Type.Value_Type.Value_Type
from project.Internal.Fan_Out import all from project.Internal.Fan_Out import all
polyglot java import org.enso.table.data.mask.OrderMask
## PRIVATE ## PRIVATE
Splits a column of text into a set of new columns. Splits a column of text into a set of new columns.
See `Table.split_to_columns`. See `Table.split_to_columns`.

View File

@ -57,15 +57,9 @@ public class IntArrayBuilder {
* <p>After calling this method, the builder is invalidated and cannot be used anymore. Any usage * <p>After calling this method, the builder is invalidated and cannot be used anymore. Any usage
* of the builder afterwards will result in a {@code NullPointerException}. * of the builder afterwards will result in a {@code NullPointerException}.
*/ */
public int[] unsafeGetStorageAndInvalidateTheBuilder() { public int[] unsafeGetResultAndInvalidate() {
int[] tmp = storage; int[] tmp = storage;
this.storage = null; this.storage = null;
return tmp; return tmp;
} }
public int[] build() {
int[] result = new int[length];
System.arraycopy(storage, 0, result, 0, length);
return result;
}
} }

View File

@ -158,9 +158,9 @@ public class CaseFoldedString {
return new CaseFoldedString( return new CaseFoldedString(
stringBuilder.toString(), stringBuilder.toString(),
grapheme_mapping.unsafeGetStorageAndInvalidateTheBuilder(), grapheme_mapping.unsafeGetResultAndInvalidate(),
codeunit_start_mapping.unsafeGetStorageAndInvalidateTheBuilder(), codeunit_start_mapping.unsafeGetResultAndInvalidate(),
codeunit_end_mapping.unsafeGetStorageAndInvalidateTheBuilder()); codeunit_end_mapping.unsafeGetResultAndInvalidate());
} }
/** /**

View File

@ -203,19 +203,19 @@ public final class BoolStorage extends Storage<Boolean> {
@Override @Override
public BoolStorage applyMask(OrderMask mask) { public BoolStorage applyMask(OrderMask mask) {
Context context = Context.getCurrent(); Context context = Context.getCurrent();
int[] positions = mask.getPositions();
BitSet newNa = new BitSet(); BitSet newNa = new BitSet();
BitSet newVals = new BitSet(); BitSet newVals = new BitSet();
for (int i = 0; i < positions.length; i++) { for (int i = 0; i < mask.length(); i++) {
if (positions[i] == Index.NOT_FOUND || isMissing.get(positions[i])) { int position = mask.get(i);
if (position == Index.NOT_FOUND || isMissing.get(position)) {
newNa.set(i); newNa.set(i);
} else if (values.get(positions[i])) { } else if (values.get(position)) {
newVals.set(i); newVals.set(i);
} }
context.safepoint(); context.safepoint();
} }
return new BoolStorage(newVals, newNa, positions.length, negated); return new BoolStorage(newVals, newNa, mask.length(), negated);
} }
@Override @Override

View File

@ -127,18 +127,13 @@ public abstract class SpecializedStorage<T> extends Storage<T> {
@Override @Override
public SpecializedStorage<T> applyMask(OrderMask mask) { public SpecializedStorage<T> applyMask(OrderMask mask) {
Context context = Context.getCurrent(); Context context = Context.getCurrent();
int[] positions = mask.getPositions(); T[] newData = newUnderlyingArray(mask.length());
T[] newData = newUnderlyingArray(positions.length); for (int i = 0; i < mask.length(); i++) {
for (int i = 0; i < positions.length; i++) { int position = mask.get(i);
if (positions[i] == Index.NOT_FOUND) { newData[i] = position == Index.NOT_FOUND ? null : data[position];
newData[i] = null;
} else {
newData[i] = data[positions[i]];
}
context.safepoint(); context.safepoint();
} }
return newInstance(newData, positions.length); return newInstance(newData, newData.length);
} }
@Override @Override

View File

@ -80,20 +80,20 @@ public abstract class ComputedLongStorage extends AbstractLongStorage {
@Override @Override
public Storage<Long> applyMask(OrderMask mask) { public Storage<Long> applyMask(OrderMask mask) {
int[] positions = mask.getPositions(); long[] newData = new long[mask.length()];
long[] newData = new long[positions.length];
BitSet newMissing = new BitSet(); BitSet newMissing = new BitSet();
Context context = Context.getCurrent(); Context context = Context.getCurrent();
for (int i = 0; i < positions.length; i++) { for (int i = 0; i < mask.length(); i++) {
if (positions[i] == Index.NOT_FOUND) { int position = mask.get(i);
if (position == Index.NOT_FOUND) {
newMissing.set(i); newMissing.set(i);
} else { } else {
newData[i] = getItem(positions[i]); newData[i] = getItem(position);
} }
context.safepoint(); context.safepoint();
} }
return new LongStorage(newData, positions.length, newMissing, getType()); return new LongStorage(newData, newData.length, newMissing, getType());
} }
@Override @Override

View File

@ -100,15 +100,15 @@ public abstract class ComputedNullableLongStorage extends AbstractLongStorage {
@Override @Override
public Storage<Long> applyMask(OrderMask mask) { public Storage<Long> applyMask(OrderMask mask) {
int[] positions = mask.getPositions(); long[] newData = new long[mask.length()];
long[] newData = new long[positions.length];
BitSet newMissing = new BitSet(); BitSet newMissing = new BitSet();
Context context = Context.getCurrent(); Context context = Context.getCurrent();
for (int i = 0; i < positions.length; i++) { for (int i = 0; i < mask.length(); i++) {
if (positions[i] == Index.NOT_FOUND) { int position = mask.get(i);
if (position == Index.NOT_FOUND) {
newMissing.set(i); newMissing.set(i);
} else { } else {
Long item = computeItem(positions[i]); Long item = computeItem(position);
if (item == null) { if (item == null) {
newMissing.set(i); newMissing.set(i);
} else { } else {
@ -118,7 +118,7 @@ public abstract class ComputedNullableLongStorage extends AbstractLongStorage {
context.safepoint(); context.safepoint();
} }
return new LongStorage(newData, positions.length, newMissing, getType()); return new LongStorage(newData, newData.length, newMissing, getType());
} }
@Override @Override

View File

@ -278,20 +278,20 @@ public final class DoubleStorage extends NumericStorage<Double> implements Doubl
@Override @Override
public Storage<Double> applyMask(OrderMask mask) { public Storage<Double> applyMask(OrderMask mask) {
int[] positions = mask.getPositions(); long[] newData = new long[mask.length()];
long[] newData = new long[positions.length];
BitSet newMissing = new BitSet(); BitSet newMissing = new BitSet();
Context context = Context.getCurrent(); Context context = Context.getCurrent();
for (int i = 0; i < positions.length; i++) { for (int i = 0; i < mask.length(); i++) {
if (positions[i] == Index.NOT_FOUND || isMissing.get(positions[i])) { int position = mask.get(i);
if (position == Index.NOT_FOUND || isMissing.get(position)) {
newMissing.set(i); newMissing.set(i);
} else { } else {
newData[i] = data[positions[i]]; newData[i] = data[position];
} }
context.safepoint(); context.safepoint();
} }
return new DoubleStorage(newData, positions.length, newMissing); return new DoubleStorage(newData, newData.length, newMissing);
} }
@Override @Override

View File

@ -191,20 +191,20 @@ public final class LongStorage extends AbstractLongStorage {
@Override @Override
public Storage<Long> applyMask(OrderMask mask) { public Storage<Long> applyMask(OrderMask mask) {
int[] positions = mask.getPositions(); long[] newData = new long[mask.length()];
long[] newData = new long[positions.length];
BitSet newMissing = new BitSet(); BitSet newMissing = new BitSet();
Context context = Context.getCurrent(); Context context = Context.getCurrent();
for (int i = 0; i < positions.length; i++) { for (int i = 0; i < mask.length(); i++) {
if (positions[i] == Index.NOT_FOUND || isMissing.get(positions[i])) { int position = mask.get(i);
if (position == Index.NOT_FOUND || isMissing.get(position)) {
newMissing.set(i); newMissing.set(i);
} else { } else {
newData[i] = data[positions[i]]; newData[i] = data[position];
} }
context.safepoint(); context.safepoint();
} }
return new LongStorage(newData, positions.length, newMissing, type); return new LongStorage(newData, newData.length, newMissing, type);
} }
@Override @Override

View File

@ -1,69 +1,96 @@
package org.enso.table.data.mask; package org.enso.table.data.mask;
import java.util.Arrays; import java.util.function.ToIntFunction;
import java.util.List;
import org.graalvm.polyglot.Context;
/** Describes a storage reordering operator. */ /** Describes a storage reordering operator. */
public class OrderMask { public interface OrderMask {
private final int[] positions; int length();
/** /**
* Creates a new reordering operator, with the specified characteristics. See {@link * Describes the reordering that should happen on the applying storage at the index.
* #getPositions()} for a description of the semantics.
*
* @param positions the positions array, as described by {@link #getPositions()}
*/
public OrderMask(int[] positions) {
this.positions = positions;
}
/**
* Describes the reordering that should happen on the applying storage.
* *
* <p>The resulting storage should contain the {@code positions[i]}-th element of the original * <p>The resulting storage should contain the {@code positions[i]}-th element of the original
* storage at the i-th position. {@code positions[i]} may be equal to {@link * storage at the {@code idx}-th position. It may return {@link
* org.enso.table.data.index.Index.NOT_FOUND}, in which case a missing value should be inserted at * org.enso.table.data.index.Index.NOT_FOUND}, in which case a missing value should be inserted at
* this position. * this position.
*/ */
public int[] getPositions() { int get(int idx);
return positions;
static OrderMask empty() {
return new OrderMaskFromArray(new int[0], 0);
} }
public OrderMask append(OrderMask other) { static OrderMask reverse(int size) {
int[] result = Arrays.copyOf(positions, positions.length + other.positions.length); return new OrderMaskReversed(size);
System.arraycopy(other.positions, 0, result, positions.length, other.positions.length);
return new OrderMask(result);
} }
public static OrderMask empty() { static OrderMask fromArray(int[] positions) {
return new OrderMask(new int[0]); return fromArray(positions, positions.length);
} }
public static OrderMask fromList(List<Integer> positions) { static OrderMask fromArray(int[] positions, int length) {
Context context = Context.getCurrent(); return new OrderMaskFromArray(positions, length);
int[] result = new int[positions.size()]; }
for (int i = 0; i < positions.size(); i++) {
result[i] = positions.get(i); static <T> OrderMask fromObjects(T[] input, ToIntFunction<T> function) {
context.safepoint(); return new OrderMaskGeneric<>(input, function);
}
class OrderMaskFromArray implements OrderMask {
private final int[] positions;
private final int length;
public OrderMaskFromArray(int[] positions, int length) {
this.positions = positions;
this.length = length;
}
@Override
public int length() {
return length;
}
@Override
public int get(int idx) {
return positions[idx];
} }
return new OrderMask(result);
} }
public static OrderMask concat(List<OrderMask> masks) { class OrderMaskGeneric<T> implements OrderMask {
Context context = Context.getCurrent(); private final T[] positions;
int size = 0; private final ToIntFunction<T> function;
for (OrderMask mask : masks) {
size += mask.positions.length; public OrderMaskGeneric(T[] positions, ToIntFunction<T> function) {
context.safepoint(); this.positions = positions;
this.function = function;
} }
int[] result = new int[size];
int offset = 0; @Override
for (OrderMask mask : masks) { public int length() {
System.arraycopy(mask.positions, 0, result, offset, mask.positions.length); return positions.length;
offset += mask.positions.length; }
context.safepoint();
@Override
public int get(int idx) {
return function.applyAsInt(positions[idx]);
}
}
class OrderMaskReversed implements OrderMask {
private final int length;
public OrderMaskReversed(int length) {
this.length = length;
}
@Override
public int length() {
return length;
}
@Override
public int get(int idx) {
return length - idx - 1;
} }
return new OrderMask(result);
} }
} }

View File

@ -222,8 +222,7 @@ public class Table {
context.safepoint(); context.safepoint();
} }
Arrays.sort(keys); Arrays.sort(keys);
int[] positions = Arrays.stream(keys).mapToInt(MultiValueKeyBase::getRowIndex).toArray(); OrderMask mask = OrderMask.fromObjects(keys, MultiValueKeyBase::getRowIndex);
OrderMask mask = new OrderMask(positions);
return this.applyMask(mask); return this.applyMask(mask);
} }

View File

@ -5,16 +5,13 @@ import org.graalvm.polyglot.Context;
public class CrossJoin { public class CrossJoin {
public static JoinResult perform(int leftRowCount, int rightRowCount) { public static JoinResult perform(int leftRowCount, int rightRowCount) {
Context context = Context.getCurrent(); Context context = Context.getCurrent();
JoinResult.BuilderSettings settings = new JoinResult.BuilderSettings(true, true, true); JoinResult.Builder resultBuilder = new JoinResult.Builder(leftRowCount * rightRowCount);
JoinResult.Builder resultBuilder =
new JoinResult.Builder(leftRowCount * rightRowCount, settings);
for (int l = 0; l < leftRowCount; ++l) { for (int l = 0; l < leftRowCount; ++l) {
for (int r = 0; r < rightRowCount; ++r) { for (int r = 0; r < rightRowCount; ++r) {
resultBuilder.addMatchedRowsPair(l, r); resultBuilder.addMatchedRowsPair(l, r);
context.safepoint(); context.safepoint();
} }
} }
return resultBuilder.buildAndInvalidate();
return resultBuilder.build();
} }
} }

View File

@ -1,21 +1,20 @@
package org.enso.table.data.table.join; package org.enso.table.data.table.join;
public enum JoinKind { public enum JoinKind {
INNER, INNER(true, false, false),
FULL, FULL(true, true, true),
LEFT_OUTER, LEFT_OUTER(true, true, false),
RIGHT_OUTER, RIGHT_OUTER(true, false, true),
LEFT_ANTI, LEFT_ANTI(false, true, false),
RIGHT_ANTI; RIGHT_ANTI(false, false, true);
public static JoinResult.BuilderSettings makeSettings(JoinKind joinKind) { public final boolean wantsCommon;
return switch (joinKind) { public final boolean wantsLeftUnmatched;
case INNER -> new JoinResult.BuilderSettings(true, false, false); public final boolean wantsRightUnmatched;
case FULL -> new JoinResult.BuilderSettings(true, true, true);
case LEFT_OUTER -> new JoinResult.BuilderSettings(true, true, false); private JoinKind(boolean wantsCommon, boolean wantsLeftUnmatched, boolean wantsRightUnmatched) {
case RIGHT_OUTER -> new JoinResult.BuilderSettings(true, false, true); this.wantsCommon = wantsCommon;
case LEFT_ANTI -> new JoinResult.BuilderSettings(false, true, false); this.wantsLeftUnmatched = wantsLeftUnmatched;
case RIGHT_ANTI -> new JoinResult.BuilderSettings(false, false, true); this.wantsRightUnmatched = wantsRightUnmatched;
};
} }
} }

View File

@ -3,33 +3,39 @@ package org.enso.table.data.table.join;
import org.enso.base.arrays.IntArrayBuilder; import org.enso.base.arrays.IntArrayBuilder;
import org.enso.table.data.mask.OrderMask; import org.enso.table.data.mask.OrderMask;
public record JoinResult(int[] matchedRowsLeftIndices, int[] matchedRowsRightIndices) { public class JoinResult {
private final int length;
private final int[] leftIndices;
private final int[] rightIndices;
public JoinResult(int[] leftIndices, int[] rightIndices, int length) {
this.length = length;
this.leftIndices = leftIndices;
this.rightIndices = rightIndices;
}
/** Represents a pair of indices of matched rows. -1 means an unmatched row. */
public record RowPair(int leftIndex, int rightIndex) {}
public OrderMask getLeftOrderMask() { public OrderMask getLeftOrderMask() {
return new OrderMask(matchedRowsLeftIndices); return OrderMask.fromArray(leftIndices, length);
} }
public OrderMask getRightOrderMask() { public OrderMask getRightOrderMask() {
return new OrderMask(matchedRowsRightIndices); return OrderMask.fromArray(rightIndices, length);
} }
public record BuilderSettings(
boolean wantsCommon, boolean wantsLeftUnmatched, boolean wantsRightUnmatched) {}
public static class Builder { public static class Builder {
IntArrayBuilder leftIndices; IntArrayBuilder leftIndices;
IntArrayBuilder rightIndices; IntArrayBuilder rightIndices;
final BuilderSettings settings; public Builder(int initialCapacity) {
public Builder(int initialCapacity, BuilderSettings settings) {
leftIndices = new IntArrayBuilder(initialCapacity); leftIndices = new IntArrayBuilder(initialCapacity);
rightIndices = new IntArrayBuilder(initialCapacity); rightIndices = new IntArrayBuilder(initialCapacity);
this.settings = settings;
} }
public Builder(BuilderSettings settings) { public Builder() {
this(128, settings); this(128);
} }
public void addMatchedRowsPair(int leftIndex, int rightIndex) { public void addMatchedRowsPair(int leftIndex, int rightIndex) {
@ -47,8 +53,22 @@ public record JoinResult(int[] matchedRowsLeftIndices, int[] matchedRowsRightInd
rightIndices.add(rightIndex); rightIndices.add(rightIndex);
} }
public JoinResult build() { /**
return new JoinResult(leftIndices.build(), rightIndices.build()); * Returns the result of the builder.
*
* <p>This method avoids copying for performance. After calling this method, the builder is
* invalidated and cannot be used anymore. Any usage of the builder afterwards will result in a
* {@code NullPointerException}.
*/
public JoinResult buildAndInvalidate() {
var left = leftIndices;
var right = rightIndices;
leftIndices = null;
rightIndices = null;
return new JoinResult(
left.unsafeGetResultAndInvalidate(),
right.unsafeGetResultAndInvalidate(),
left.getLength());
} }
} }
} }

View File

@ -15,8 +15,6 @@ public interface JoinStrategy {
static JoinStrategy createStrategy(List<JoinCondition> conditions, JoinKind joinKind) { static JoinStrategy createStrategy(List<JoinCondition> conditions, JoinKind joinKind) {
ensureConditionsNotEmpty(conditions); ensureConditionsNotEmpty(conditions);
JoinResult.BuilderSettings builderSettings = JoinKind.makeSettings(joinKind);
List<HashableCondition> hashableConditions = List<HashableCondition> hashableConditions =
conditions.stream() conditions.stream()
.filter(c -> c instanceof HashableCondition) .filter(c -> c instanceof HashableCondition)
@ -31,12 +29,14 @@ public interface JoinStrategy {
if (hashableConditions.isEmpty()) { if (hashableConditions.isEmpty()) {
assert !betweenConditions.isEmpty(); assert !betweenConditions.isEmpty();
return new SortJoin(betweenConditions, builderSettings); return new SortJoin(betweenConditions, joinKind);
} else if (betweenConditions.isEmpty()) { } else if (betweenConditions.isEmpty()) {
return new HashJoin(hashableConditions, new MatchAllStrategy(), builderSettings);
} else {
return new HashJoin( return new HashJoin(
hashableConditions, new SortJoin(betweenConditions, builderSettings), builderSettings); hashableConditions,
joinKind.wantsCommon ? new MatchAllStrategy() : new NoOpStrategy(),
joinKind);
} else {
return new HashJoin(hashableConditions, new SortJoin(betweenConditions, joinKind), joinKind);
} }
} }

View File

@ -15,10 +15,6 @@ public class MatchAllStrategy implements PluggableJoinStrategy {
List<Integer> rightGroup, List<Integer> rightGroup,
JoinResult.Builder resultBuilder, JoinResult.Builder resultBuilder,
ProblemAggregator problemAggregator) { ProblemAggregator problemAggregator) {
if (!resultBuilder.settings.wantsCommon()) {
return;
}
Context context = Context.getCurrent(); Context context = Context.getCurrent();
for (var leftRow : leftGroup) { for (var leftRow : leftGroup) {
for (var rightRow : rightGroup) { for (var rightRow : rightGroup) {

View File

@ -0,0 +1,15 @@
package org.enso.table.data.table.join;
import java.util.List;
import org.enso.table.problems.ProblemAggregator;
/**
 * A {@link PluggableJoinStrategy} that never contributes any rows to the join result.
 *
 * <p>{@code JoinStrategy.createStrategy} picks this strategy for hash joins whose {@code JoinKind}
 * does not want the common (matched) rows, in place of {@code MatchAllStrategy}.
 */
public class NoOpStrategy implements PluggableJoinStrategy {
  /**
   * Deliberately does nothing: none of the arguments are read or modified.
   *
   * @param leftGroup row indices of the left group (ignored)
   * @param rightGroup row indices of the right group (ignored)
   * @param resultBuilder the builder collecting the join result (left untouched)
   * @param problemAggregator the aggregator for reported problems (left untouched)
   */
  @Override
  public void joinSubsets(
      List<Integer> leftGroup,
      List<Integer> rightGroup,
      JoinResult.Builder resultBuilder,
      ProblemAggregator problemAggregator) {
    // Intentionally empty: this strategy adds no matches.
  }
}

View File

@ -7,6 +7,7 @@ import java.util.List;
import org.enso.base.ObjectComparator; import org.enso.base.ObjectComparator;
import org.enso.table.data.column.storage.Storage; import org.enso.table.data.column.storage.Storage;
import org.enso.table.data.index.OrderedMultiValueKey; import org.enso.table.data.index.OrderedMultiValueKey;
import org.enso.table.data.table.join.JoinKind;
import org.enso.table.data.table.join.JoinResult; import org.enso.table.data.table.join.JoinResult;
import org.enso.table.data.table.join.JoinStrategy; import org.enso.table.data.table.join.JoinStrategy;
import org.enso.table.data.table.join.PluggableJoinStrategy; import org.enso.table.data.table.join.PluggableJoinStrategy;
@ -16,9 +17,9 @@ import org.graalvm.polyglot.Context;
public class SortJoin implements JoinStrategy, PluggableJoinStrategy { public class SortJoin implements JoinStrategy, PluggableJoinStrategy {
public SortJoin(List<Between> conditions, JoinResult.BuilderSettings resultBuilderSettings) { public SortJoin(List<Between> conditions, JoinKind joinKind) {
JoinStrategy.ensureConditionsNotEmpty(conditions); JoinStrategy.ensureConditionsNotEmpty(conditions);
this.resultBuilderSettings = resultBuilderSettings; this.joinKind = joinKind;
Context context = Context.getCurrent(); Context context = Context.getCurrent();
int nConditions = conditions.size(); int nConditions = conditions.size();
@ -35,7 +36,7 @@ public class SortJoin implements JoinStrategy, PluggableJoinStrategy {
} }
} }
private final JoinResult.BuilderSettings resultBuilderSettings; private final JoinKind joinKind;
private final int[] directions; private final int[] directions;
private final Storage<?>[] leftStorages; private final Storage<?>[] leftStorages;
@ -46,13 +47,13 @@ public class SortJoin implements JoinStrategy, PluggableJoinStrategy {
@Override @Override
public JoinResult join(ProblemAggregator problemAggregator) { public JoinResult join(ProblemAggregator problemAggregator) {
Context context = Context.getCurrent(); Context context = Context.getCurrent();
JoinResult.Builder resultBuilder = new JoinResult.Builder(resultBuilderSettings); JoinResult.Builder resultBuilder = new JoinResult.Builder();
int leftRowCount = leftStorages[0].size(); int leftRowCount = leftStorages[0].size();
int rightRowCount = lowerStorages[0].size(); int rightRowCount = lowerStorages[0].size();
if (leftRowCount == 0 || rightRowCount == 0) { if (leftRowCount == 0 || rightRowCount == 0) {
// if one group is completely empty, there will be no matches to report // if one group is completely empty, there will be no matches to report
return resultBuilder.build(); return resultBuilder.buildAndInvalidate();
} }
List<OrderedMultiValueKey> leftKeys = new ArrayList<>(leftRowCount); List<OrderedMultiValueKey> leftKeys = new ArrayList<>(leftRowCount);
for (int i = 0; i < leftRowCount; i++) { for (int i = 0; i < leftRowCount; i++) {
@ -64,13 +65,13 @@ public class SortJoin implements JoinStrategy, PluggableJoinStrategy {
for (int rightRowIx = 0; rightRowIx < rightRowCount; rightRowIx++) { for (int rightRowIx = 0; rightRowIx < rightRowCount; rightRowIx++) {
int matches = addMatchingLeftRows(leftIndex, rightRowIx, resultBuilder); int matches = addMatchingLeftRows(leftIndex, rightRowIx, resultBuilder);
if (resultBuilderSettings.wantsRightUnmatched() && matches == 0) { if (joinKind.wantsRightUnmatched && matches == 0) {
resultBuilder.addUnmatchedRightRow(rightRowIx); resultBuilder.addUnmatchedRightRow(rightRowIx);
} }
context.safepoint(); context.safepoint();
} }
if (resultBuilderSettings.wantsLeftUnmatched()) { if (joinKind.wantsLeftUnmatched) {
for (int leftRowIx = 0; leftRowIx < leftRowCount; leftRowIx++) { for (int leftRowIx = 0; leftRowIx < leftRowCount; leftRowIx++) {
if (!matchedLeftRows.get(leftRowIx)) { if (!matchedLeftRows.get(leftRowIx)) {
resultBuilder.addUnmatchedLeftRow(leftRowIx); resultBuilder.addUnmatchedLeftRow(leftRowIx);
@ -79,7 +80,7 @@ public class SortJoin implements JoinStrategy, PluggableJoinStrategy {
} }
} }
return resultBuilder.build(); return resultBuilder.buildAndInvalidate();
} }
@Override @Override
@ -103,13 +104,13 @@ public class SortJoin implements JoinStrategy, PluggableJoinStrategy {
for (int rightRowIx : rightGroup) { for (int rightRowIx : rightGroup) {
int matches = addMatchingLeftRows(leftIndex, rightRowIx, resultBuilder); int matches = addMatchingLeftRows(leftIndex, rightRowIx, resultBuilder);
if (resultBuilderSettings.wantsRightUnmatched() && matches == 0) { if (joinKind.wantsRightUnmatched && matches == 0) {
resultBuilder.addUnmatchedRightRow(rightRowIx); resultBuilder.addUnmatchedRightRow(rightRowIx);
} }
context.safepoint(); context.safepoint();
} }
if (resultBuilderSettings.wantsLeftUnmatched()) { if (joinKind.wantsLeftUnmatched) {
for (int leftRowIx : leftGroup) { for (int leftRowIx : leftGroup) {
if (!matchedLeftRows.get(leftRowIx)) { if (!matchedLeftRows.get(leftRowIx)) {
resultBuilder.addUnmatchedLeftRow(leftRowIx); resultBuilder.addUnmatchedLeftRow(leftRowIx);
@ -161,10 +162,10 @@ public class SortJoin implements JoinStrategy, PluggableJoinStrategy {
if (isInRange(key, lowerBound, upperBound)) { if (isInRange(key, lowerBound, upperBound)) {
int leftRowIx = key.getRowIndex(); int leftRowIx = key.getRowIndex();
matchCount++; matchCount++;
if (resultBuilderSettings.wantsCommon()) { if (joinKind.wantsCommon) {
resultBuilder.addMatchedRowsPair(leftRowIx, rightRowIx); resultBuilder.addMatchedRowsPair(leftRowIx, rightRowIx);
} }
if (resultBuilderSettings.wantsLeftUnmatched()) { if (joinKind.wantsLeftUnmatched) {
matchedLeftRows.set(leftRowIx); matchedLeftRows.set(leftRowIx);
} }
} }

View File

@ -5,6 +5,7 @@ import org.enso.base.text.TextFoldingStrategy;
import org.enso.table.data.index.MultiValueIndex; import org.enso.table.data.index.MultiValueIndex;
import org.enso.table.data.index.UnorderedMultiValueKey; import org.enso.table.data.index.UnorderedMultiValueKey;
import org.enso.table.data.table.Column; import org.enso.table.data.table.Column;
import org.enso.table.data.table.join.JoinKind;
import org.enso.table.data.table.join.JoinResult; import org.enso.table.data.table.join.JoinResult;
import org.enso.table.data.table.join.JoinStrategy; import org.enso.table.data.table.join.JoinStrategy;
import org.enso.table.data.table.join.PluggableJoinStrategy; import org.enso.table.data.table.join.PluggableJoinStrategy;
@ -24,10 +25,10 @@ public class HashJoin implements JoinStrategy {
public HashJoin( public HashJoin(
List<HashableCondition> conditions, List<HashableCondition> conditions,
PluggableJoinStrategy remainingMatcher, PluggableJoinStrategy remainingMatcher,
JoinResult.BuilderSettings resultBuilderSettings) { JoinKind joinKind) {
JoinStrategy.ensureConditionsNotEmpty(conditions); JoinStrategy.ensureConditionsNotEmpty(conditions);
this.remainingMatcher = remainingMatcher; this.remainingMatcher = remainingMatcher;
this.resultBuilderSettings = resultBuilderSettings; this.joinKind = joinKind;
List<HashEqualityCondition> equalConditions = List<HashEqualityCondition> equalConditions =
conditions.stream().map(HashJoin::makeHashEqualityCondition).toList(); conditions.stream().map(HashJoin::makeHashEqualityCondition).toList();
@ -46,7 +47,7 @@ public class HashJoin implements JoinStrategy {
private final Column[] leftEquals, rightEquals; private final Column[] leftEquals, rightEquals;
private final List<TextFoldingStrategy> textFoldingStrategies; private final List<TextFoldingStrategy> textFoldingStrategies;
private final PluggableJoinStrategy remainingMatcher; private final PluggableJoinStrategy remainingMatcher;
private final JoinResult.BuilderSettings resultBuilderSettings; private final JoinKind joinKind;
@Override @Override
public JoinResult join(ProblemAggregator problemAggregator) { public JoinResult join(ProblemAggregator problemAggregator) {
@ -59,7 +60,7 @@ public class HashJoin implements JoinStrategy {
MultiValueIndex.makeUnorderedIndex( MultiValueIndex.makeUnorderedIndex(
rightEquals, rightEquals[0].getSize(), textFoldingStrategies, problemAggregator); rightEquals, rightEquals[0].getSize(), textFoldingStrategies, problemAggregator);
JoinResult.Builder resultBuilder = new JoinResult.Builder(resultBuilderSettings); JoinResult.Builder resultBuilder = new JoinResult.Builder();
for (var leftEntry : leftIndex.mapping().entrySet()) { for (var leftEntry : leftIndex.mapping().entrySet()) {
UnorderedMultiValueKey leftKey = leftEntry.getKey(); UnorderedMultiValueKey leftKey = leftEntry.getKey();
List<Integer> leftRows = leftEntry.getValue(); List<Integer> leftRows = leftEntry.getValue();
@ -68,7 +69,7 @@ public class HashJoin implements JoinStrategy {
if (rightRows != null) { if (rightRows != null) {
remainingMatcher.joinSubsets(leftRows, rightRows, resultBuilder, problemAggregator); remainingMatcher.joinSubsets(leftRows, rightRows, resultBuilder, problemAggregator);
} else { } else {
if (resultBuilderSettings.wantsLeftUnmatched()) { if (joinKind.wantsLeftUnmatched) {
for (int leftRow : leftRows) { for (int leftRow : leftRows) {
resultBuilder.addUnmatchedLeftRow(leftRow); resultBuilder.addUnmatchedLeftRow(leftRow);
context.safepoint(); context.safepoint();
@ -79,7 +80,7 @@ public class HashJoin implements JoinStrategy {
context.safepoint(); context.safepoint();
} }
if (resultBuilderSettings.wantsRightUnmatched()) { if (joinKind.wantsRightUnmatched) {
for (var rightEntry : rightIndex.mapping().entrySet()) { for (var rightEntry : rightIndex.mapping().entrySet()) {
UnorderedMultiValueKey rightKey = rightEntry.getKey(); UnorderedMultiValueKey rightKey = rightEntry.getKey();
boolean wasCompletelyUnmatched = !leftIndex.contains(rightKey); boolean wasCompletelyUnmatched = !leftIndex.contains(rightKey);
@ -91,7 +92,7 @@ public class HashJoin implements JoinStrategy {
} }
} }
return resultBuilder.build(); return resultBuilder.buildAndInvalidate();
} }
private static HashEqualityCondition makeHashEqualityCondition(HashableCondition eq) { private static HashEqualityCondition makeHashEqualityCondition(HashableCondition eq) {

View File

@ -201,7 +201,7 @@ public class LookupJoin {
@Override @Override
public Column build(int[] orderMask) { public Column build(int[] orderMask) {
assert orderMask != null; assert orderMask != null;
return lookupColumn.applyMask(new OrderMask(orderMask)); return lookupColumn.applyMask(OrderMask.fromArray(orderMask));
} }
} }
} }

View File

@ -71,17 +71,6 @@ public class OrderBuilder {
int[] positions = int[] positions =
IntStream.range(0, size).boxed().sorted(comparator).mapToInt(i -> i).toArray(); IntStream.range(0, size).boxed().sorted(comparator).mapToInt(i -> i).toArray();
return new OrderMask(positions); return OrderMask.fromArray(positions);
}
/**
* Builds an order mask that will reverse the order of the data being masked.
*
* @param size the length of the data being masked
* @return an order mask that will result in reversing the data it is applied to
*/
public static OrderMask buildReversedMask(int size) {
int[] positions = IntStream.range(0, size).map(i -> size - i - 1).toArray();
return new OrderMask(positions);
} }
} }

View File

@ -153,8 +153,8 @@ create_scenario_antijoin num_rows =
## This is a scenario where we join a very large table with a much smaller table ## This is a scenario where we join a very large table with a much smaller table
to check an optimisation where we only index the smaller of the 2 tables to check an optimisation where we only index the smaller of the 2 tables
create_scenario_large_small_table = create_scenario_large_small_table =
xs = (0.up_to 1000000).map _-> Random.integer 0 99 xs = (0.up_to 10000000).map _-> Random.integer 0 999
ys = (0.up_to 100).to_vector ys = (0.up_to 1000).to_vector
table1 = Table.new [["key", xs]] table1 = Table.new [["key", xs]]
table2 = Table.new [["key", ys]] table2 = Table.new [["key", ys]]
Scenario.Value table1 table2 Scenario.Value table1 table2
@ -228,12 +228,12 @@ collect_benches = Bench.build builder->
r = scenario.table2.join scenario.table1 on="key" join_kind=Join_Kind.Left_Exclusive r = scenario.table2.join scenario.table1 on="key" join_kind=Join_Kind.Left_Exclusive
assert (r.row_count == 1000) assert (r.row_count == 1000)
if extended_tests then group_builder.specify "Join_Large_Table_to_Small_Table" <| group_builder.specify "Join_Large_Table_to_Small_Table" <|
scenario = data.large_small_table scenario = data.large_small_table
r = scenario.table1.join scenario.table2 on="key" r = scenario.table1.join scenario.table2 on="key"
assert (r.row_count == scenario.table1.row_count) assert (r.row_count == scenario.table1.row_count)
if extended_tests then group_builder.specify "Join_Small_Table_to_Large_Table" <| group_builder.specify "Join_Small_Table_to_Large_Table" <|
scenario = data.large_small_table scenario = data.large_small_table
r = scenario.table2.join scenario.table1 on="key" r = scenario.table2.join scenario.table1 on="key"
assert (r.row_count == scenario.table1.row_count) assert (r.row_count == scenario.table1.row_count)