Refactor OrderMask to avoid memory copying (#8863)

The goal of this PR is to refactor the design of OrderMask to avoid copying arrays or lists wherever possible.
We have removed a few legacy functions which were not being used.

On a poor man's benchmark, the new implementation seems to be quicker (13s vs 16s), and memory usage should be lower.
This commit is contained in:
James Dunkerley 2024-01-26 11:16:16 +00:00 committed by GitHub
parent 5dd2dc1c93
commit 0b6db5797c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 213 additions and 183 deletions

View File

@ -32,6 +32,7 @@ from project.Internal.Column_Format import all
from project.Internal.Java_Exports import make_date_builder_adapter, make_string_builder
polyglot java import org.enso.base.Time_Utils
polyglot java import org.enso.table.data.mask.OrderMask
polyglot java import org.enso.table.data.column.operation.cast.CastProblemAggregator
polyglot java import org.enso.table.data.column.storage.Storage as Java_Storage
polyglot java import org.enso.table.data.table.Column as Java_Column
@ -2284,7 +2285,7 @@ type Column
example_reverse = Examples.integer_column.reverse
reverse : Column
reverse self =
mask = OrderBuilder.buildReversedMask self.length
mask = OrderMask.reverse self.length
Column.Value (self.java_column.applyMask mask)
## GROUP Standard.Base.Metadata

View File

@ -78,7 +78,6 @@ polyglot java import org.enso.table.error.NonUniqueLookupKey
polyglot java import org.enso.table.error.NullValuesInKeyColumns
polyglot java import org.enso.table.error.TooManyColumnsException
polyglot java import org.enso.table.error.UnmatchedRow
polyglot java import org.enso.table.operations.OrderBuilder
polyglot java import org.enso.table.parsing.problems.ParseProblemAggregator
## Represents a column-oriented table data structure.
@ -2423,7 +2422,7 @@ type Table
example_reverse = Examples.inventory_table.reverse
reverse : Table
reverse self =
mask = OrderBuilder.buildReversedMask self.row_count
mask = OrderMask.reverse self.row_count
Table.Value <| self.java_table.applyMask mask
## GROUP Standard.Base.Output

View File

@ -12,7 +12,6 @@ polyglot java import java.lang.IllegalArgumentException
polyglot java import java.time.temporal.UnsupportedTemporalTypeException
polyglot java import org.enso.table.data.column.storage.Storage as Java_Storage
polyglot java import org.enso.table.data.table.Column as Java_Column
polyglot java import org.enso.table.operations.OrderBuilder
## PRIVATE
Create a formatter for the specified `Value_Type`.

View File

@ -137,7 +137,7 @@ fan_out_to_rows_and_columns table input_column_id function column_names at_least
Column.from_storage column_name output_storage
# Build the order mask.
order_mask = OrderMask.new (order_mask_positions.to_vector)
order_mask = OrderMask.fromArray (order_mask_positions.to_vector)
## Build the new table, replacing the input column with the new output
columns.

View File

@ -4,8 +4,6 @@ import project.Data.Table.Table
import project.Data.Type.Value_Type.Value_Type
from project.Internal.Fan_Out import all
polyglot java import org.enso.table.data.mask.OrderMask
## PRIVATE
Splits a column of text into a set of new columns.
See `Table.split_to_columns`.

View File

@ -57,15 +57,9 @@ public class IntArrayBuilder {
* <p>After calling this method, the builder is invalidated and cannot be used anymore. Any usage
* of the builder afterwards will result in a {@code NullPointerException}.
*/
public int[] unsafeGetStorageAndInvalidateTheBuilder() {
public int[] unsafeGetResultAndInvalidate() {
int[] tmp = storage;
this.storage = null;
return tmp;
}
public int[] build() {
int[] result = new int[length];
System.arraycopy(storage, 0, result, 0, length);
return result;
}
}

View File

@ -158,9 +158,9 @@ public class CaseFoldedString {
return new CaseFoldedString(
stringBuilder.toString(),
grapheme_mapping.unsafeGetStorageAndInvalidateTheBuilder(),
codeunit_start_mapping.unsafeGetStorageAndInvalidateTheBuilder(),
codeunit_end_mapping.unsafeGetStorageAndInvalidateTheBuilder());
grapheme_mapping.unsafeGetResultAndInvalidate(),
codeunit_start_mapping.unsafeGetResultAndInvalidate(),
codeunit_end_mapping.unsafeGetResultAndInvalidate());
}
/**

View File

@ -203,19 +203,19 @@ public final class BoolStorage extends Storage<Boolean> {
@Override
public BoolStorage applyMask(OrderMask mask) {
Context context = Context.getCurrent();
int[] positions = mask.getPositions();
BitSet newNa = new BitSet();
BitSet newVals = new BitSet();
for (int i = 0; i < positions.length; i++) {
if (positions[i] == Index.NOT_FOUND || isMissing.get(positions[i])) {
for (int i = 0; i < mask.length(); i++) {
int position = mask.get(i);
if (position == Index.NOT_FOUND || isMissing.get(position)) {
newNa.set(i);
} else if (values.get(positions[i])) {
} else if (values.get(position)) {
newVals.set(i);
}
context.safepoint();
}
return new BoolStorage(newVals, newNa, positions.length, negated);
return new BoolStorage(newVals, newNa, mask.length(), negated);
}
@Override

View File

@ -127,18 +127,13 @@ public abstract class SpecializedStorage<T> extends Storage<T> {
@Override
public SpecializedStorage<T> applyMask(OrderMask mask) {
Context context = Context.getCurrent();
int[] positions = mask.getPositions();
T[] newData = newUnderlyingArray(positions.length);
for (int i = 0; i < positions.length; i++) {
if (positions[i] == Index.NOT_FOUND) {
newData[i] = null;
} else {
newData[i] = data[positions[i]];
}
T[] newData = newUnderlyingArray(mask.length());
for (int i = 0; i < mask.length(); i++) {
int position = mask.get(i);
newData[i] = position == Index.NOT_FOUND ? null : data[position];
context.safepoint();
}
return newInstance(newData, positions.length);
return newInstance(newData, newData.length);
}
@Override

View File

@ -80,20 +80,20 @@ public abstract class ComputedLongStorage extends AbstractLongStorage {
@Override
public Storage<Long> applyMask(OrderMask mask) {
int[] positions = mask.getPositions();
long[] newData = new long[positions.length];
long[] newData = new long[mask.length()];
BitSet newMissing = new BitSet();
Context context = Context.getCurrent();
for (int i = 0; i < positions.length; i++) {
if (positions[i] == Index.NOT_FOUND) {
for (int i = 0; i < mask.length(); i++) {
int position = mask.get(i);
if (position == Index.NOT_FOUND) {
newMissing.set(i);
} else {
newData[i] = getItem(positions[i]);
newData[i] = getItem(position);
}
context.safepoint();
}
return new LongStorage(newData, positions.length, newMissing, getType());
return new LongStorage(newData, newData.length, newMissing, getType());
}
@Override

View File

@ -100,15 +100,15 @@ public abstract class ComputedNullableLongStorage extends AbstractLongStorage {
@Override
public Storage<Long> applyMask(OrderMask mask) {
int[] positions = mask.getPositions();
long[] newData = new long[positions.length];
long[] newData = new long[mask.length()];
BitSet newMissing = new BitSet();
Context context = Context.getCurrent();
for (int i = 0; i < positions.length; i++) {
if (positions[i] == Index.NOT_FOUND) {
for (int i = 0; i < mask.length(); i++) {
int position = mask.get(i);
if (position == Index.NOT_FOUND) {
newMissing.set(i);
} else {
Long item = computeItem(positions[i]);
Long item = computeItem(position);
if (item == null) {
newMissing.set(i);
} else {
@ -118,7 +118,7 @@ public abstract class ComputedNullableLongStorage extends AbstractLongStorage {
context.safepoint();
}
return new LongStorage(newData, positions.length, newMissing, getType());
return new LongStorage(newData, newData.length, newMissing, getType());
}
@Override

View File

@ -278,20 +278,20 @@ public final class DoubleStorage extends NumericStorage<Double> implements Doubl
@Override
public Storage<Double> applyMask(OrderMask mask) {
int[] positions = mask.getPositions();
long[] newData = new long[positions.length];
long[] newData = new long[mask.length()];
BitSet newMissing = new BitSet();
Context context = Context.getCurrent();
for (int i = 0; i < positions.length; i++) {
if (positions[i] == Index.NOT_FOUND || isMissing.get(positions[i])) {
for (int i = 0; i < mask.length(); i++) {
int position = mask.get(i);
if (position == Index.NOT_FOUND || isMissing.get(position)) {
newMissing.set(i);
} else {
newData[i] = data[positions[i]];
newData[i] = data[position];
}
context.safepoint();
}
return new DoubleStorage(newData, positions.length, newMissing);
return new DoubleStorage(newData, newData.length, newMissing);
}
@Override

View File

@ -191,20 +191,20 @@ public final class LongStorage extends AbstractLongStorage {
@Override
public Storage<Long> applyMask(OrderMask mask) {
int[] positions = mask.getPositions();
long[] newData = new long[positions.length];
long[] newData = new long[mask.length()];
BitSet newMissing = new BitSet();
Context context = Context.getCurrent();
for (int i = 0; i < positions.length; i++) {
if (positions[i] == Index.NOT_FOUND || isMissing.get(positions[i])) {
for (int i = 0; i < mask.length(); i++) {
int position = mask.get(i);
if (position == Index.NOT_FOUND || isMissing.get(position)) {
newMissing.set(i);
} else {
newData[i] = data[positions[i]];
newData[i] = data[position];
}
context.safepoint();
}
return new LongStorage(newData, positions.length, newMissing, type);
return new LongStorage(newData, newData.length, newMissing, type);
}
@Override

View File

@ -1,69 +1,96 @@
package org.enso.table.data.mask;
import java.util.Arrays;
import java.util.List;
import org.graalvm.polyglot.Context;
import java.util.function.ToIntFunction;
/** Describes a storage reordering operator. */
public class OrderMask {
private final int[] positions;
public interface OrderMask {
int length();
/**
* Creates a new reordering operator, with the specified characteristics. See {@link
* #getPositions()} for a description of the semantics.
*
* @param positions the positions array, as described by {@link #getPositions()}
*/
public OrderMask(int[] positions) {
this.positions = positions;
}
/**
* Describes the reordering that should happen on the applying storage.
* Describes the reordering that should happen on the applying storage at the index.
*
* <p>The resulting storage should contain the {@code positions[i]}-th element of the original
* storage at the i-th position. {@code positions[i]} may be equal to {@link
* storage at the {@code idx}-th position. It may return {@link
* org.enso.table.data.index.Index#NOT_FOUND}, in which case a missing value should be inserted at
* this position.
*/
public int[] getPositions() {
return positions;
int get(int idx);
static OrderMask empty() {
return new OrderMaskFromArray(new int[0], 0);
}
public OrderMask append(OrderMask other) {
int[] result = Arrays.copyOf(positions, positions.length + other.positions.length);
System.arraycopy(other.positions, 0, result, positions.length, other.positions.length);
return new OrderMask(result);
static OrderMask reverse(int size) {
return new OrderMaskReversed(size);
}
public static OrderMask empty() {
return new OrderMask(new int[0]);
static OrderMask fromArray(int[] positions) {
return fromArray(positions, positions.length);
}
public static OrderMask fromList(List<Integer> positions) {
Context context = Context.getCurrent();
int[] result = new int[positions.size()];
for (int i = 0; i < positions.size(); i++) {
result[i] = positions.get(i);
context.safepoint();
static OrderMask fromArray(int[] positions, int length) {
return new OrderMaskFromArray(positions, length);
}
static <T> OrderMask fromObjects(T[] input, ToIntFunction<T> function) {
return new OrderMaskGeneric<>(input, function);
}
class OrderMaskFromArray implements OrderMask {
private final int[] positions;
private final int length;
public OrderMaskFromArray(int[] positions, int length) {
this.positions = positions;
this.length = length;
}
@Override
public int length() {
return length;
}
@Override
public int get(int idx) {
return positions[idx];
}
return new OrderMask(result);
}
public static OrderMask concat(List<OrderMask> masks) {
Context context = Context.getCurrent();
int size = 0;
for (OrderMask mask : masks) {
size += mask.positions.length;
context.safepoint();
class OrderMaskGeneric<T> implements OrderMask {
private final T[] positions;
private final ToIntFunction<T> function;
public OrderMaskGeneric(T[] positions, ToIntFunction<T> function) {
this.positions = positions;
this.function = function;
}
int[] result = new int[size];
int offset = 0;
for (OrderMask mask : masks) {
System.arraycopy(mask.positions, 0, result, offset, mask.positions.length);
offset += mask.positions.length;
context.safepoint();
@Override
public int length() {
return positions.length;
}
@Override
public int get(int idx) {
return function.applyAsInt(positions[idx]);
}
}
class OrderMaskReversed implements OrderMask {
private final int length;
public OrderMaskReversed(int length) {
this.length = length;
}
@Override
public int length() {
return length;
}
@Override
public int get(int idx) {
return length - idx - 1;
}
return new OrderMask(result);
}
}

View File

@ -222,8 +222,7 @@ public class Table {
context.safepoint();
}
Arrays.sort(keys);
int[] positions = Arrays.stream(keys).mapToInt(MultiValueKeyBase::getRowIndex).toArray();
OrderMask mask = new OrderMask(positions);
OrderMask mask = OrderMask.fromObjects(keys, MultiValueKeyBase::getRowIndex);
return this.applyMask(mask);
}

View File

@ -5,16 +5,13 @@ import org.graalvm.polyglot.Context;
public class CrossJoin {
public static JoinResult perform(int leftRowCount, int rightRowCount) {
Context context = Context.getCurrent();
JoinResult.BuilderSettings settings = new JoinResult.BuilderSettings(true, true, true);
JoinResult.Builder resultBuilder =
new JoinResult.Builder(leftRowCount * rightRowCount, settings);
JoinResult.Builder resultBuilder = new JoinResult.Builder(leftRowCount * rightRowCount);
for (int l = 0; l < leftRowCount; ++l) {
for (int r = 0; r < rightRowCount; ++r) {
resultBuilder.addMatchedRowsPair(l, r);
context.safepoint();
}
}
return resultBuilder.build();
return resultBuilder.buildAndInvalidate();
}
}

View File

@ -1,21 +1,20 @@
package org.enso.table.data.table.join;
public enum JoinKind {
INNER,
FULL,
LEFT_OUTER,
RIGHT_OUTER,
LEFT_ANTI,
RIGHT_ANTI;
INNER(true, false, false),
FULL(true, true, true),
LEFT_OUTER(true, true, false),
RIGHT_OUTER(true, false, true),
LEFT_ANTI(false, true, false),
RIGHT_ANTI(false, false, true);
public static JoinResult.BuilderSettings makeSettings(JoinKind joinKind) {
return switch (joinKind) {
case INNER -> new JoinResult.BuilderSettings(true, false, false);
case FULL -> new JoinResult.BuilderSettings(true, true, true);
case LEFT_OUTER -> new JoinResult.BuilderSettings(true, true, false);
case RIGHT_OUTER -> new JoinResult.BuilderSettings(true, false, true);
case LEFT_ANTI -> new JoinResult.BuilderSettings(false, true, false);
case RIGHT_ANTI -> new JoinResult.BuilderSettings(false, false, true);
};
public final boolean wantsCommon;
public final boolean wantsLeftUnmatched;
public final boolean wantsRightUnmatched;
private JoinKind(boolean wantsCommon, boolean wantsLeftUnmatched, boolean wantsRightUnmatched) {
this.wantsCommon = wantsCommon;
this.wantsLeftUnmatched = wantsLeftUnmatched;
this.wantsRightUnmatched = wantsRightUnmatched;
}
}

View File

@ -3,33 +3,39 @@ package org.enso.table.data.table.join;
import org.enso.base.arrays.IntArrayBuilder;
import org.enso.table.data.mask.OrderMask;
public record JoinResult(int[] matchedRowsLeftIndices, int[] matchedRowsRightIndices) {
public class JoinResult {
private final int length;
private final int[] leftIndices;
private final int[] rightIndices;
public JoinResult(int[] leftIndices, int[] rightIndices, int length) {
this.length = length;
this.leftIndices = leftIndices;
this.rightIndices = rightIndices;
}
/** Represents a pair of indices of matched rows. -1 means an unmatched row. */
public record RowPair(int leftIndex, int rightIndex) {}
public OrderMask getLeftOrderMask() {
return new OrderMask(matchedRowsLeftIndices);
return OrderMask.fromArray(leftIndices, length);
}
public OrderMask getRightOrderMask() {
return new OrderMask(matchedRowsRightIndices);
return OrderMask.fromArray(rightIndices, length);
}
public record BuilderSettings(
boolean wantsCommon, boolean wantsLeftUnmatched, boolean wantsRightUnmatched) {}
public static class Builder {
IntArrayBuilder leftIndices;
IntArrayBuilder rightIndices;
final BuilderSettings settings;
public Builder(int initialCapacity, BuilderSettings settings) {
public Builder(int initialCapacity) {
leftIndices = new IntArrayBuilder(initialCapacity);
rightIndices = new IntArrayBuilder(initialCapacity);
this.settings = settings;
}
public Builder(BuilderSettings settings) {
this(128, settings);
public Builder() {
this(128);
}
public void addMatchedRowsPair(int leftIndex, int rightIndex) {
@ -47,8 +53,22 @@ public record JoinResult(int[] matchedRowsLeftIndices, int[] matchedRowsRightInd
rightIndices.add(rightIndex);
}
public JoinResult build() {
return new JoinResult(leftIndices.build(), rightIndices.build());
/**
* Returns the result of the builder.
*
* <p>This method avoids copying for performance. After calling this method, the builder is
* invalidated and cannot be used anymore. Any usage of the builder afterwards will result in a
* {@code NullPointerException}.
*/
public JoinResult buildAndInvalidate() {
var left = leftIndices;
var right = rightIndices;
leftIndices = null;
rightIndices = null;
return new JoinResult(
left.unsafeGetResultAndInvalidate(),
right.unsafeGetResultAndInvalidate(),
left.getLength());
}
}
}

View File

@ -15,8 +15,6 @@ public interface JoinStrategy {
static JoinStrategy createStrategy(List<JoinCondition> conditions, JoinKind joinKind) {
ensureConditionsNotEmpty(conditions);
JoinResult.BuilderSettings builderSettings = JoinKind.makeSettings(joinKind);
List<HashableCondition> hashableConditions =
conditions.stream()
.filter(c -> c instanceof HashableCondition)
@ -31,12 +29,14 @@ public interface JoinStrategy {
if (hashableConditions.isEmpty()) {
assert !betweenConditions.isEmpty();
return new SortJoin(betweenConditions, builderSettings);
return new SortJoin(betweenConditions, joinKind);
} else if (betweenConditions.isEmpty()) {
return new HashJoin(hashableConditions, new MatchAllStrategy(), builderSettings);
} else {
return new HashJoin(
hashableConditions, new SortJoin(betweenConditions, builderSettings), builderSettings);
hashableConditions,
joinKind.wantsCommon ? new MatchAllStrategy() : new NoOpStrategy(),
joinKind);
} else {
return new HashJoin(hashableConditions, new SortJoin(betweenConditions, joinKind), joinKind);
}
}

View File

@ -15,10 +15,6 @@ public class MatchAllStrategy implements PluggableJoinStrategy {
List<Integer> rightGroup,
JoinResult.Builder resultBuilder,
ProblemAggregator problemAggregator) {
if (!resultBuilder.settings.wantsCommon()) {
return;
}
Context context = Context.getCurrent();
for (var leftRow : leftGroup) {
for (var rightRow : rightGroup) {

View File

@ -0,0 +1,15 @@
package org.enso.table.data.table.join;
import java.util.List;
import org.enso.table.problems.ProblemAggregator;
/**
 * A {@link PluggableJoinStrategy} that deliberately does nothing.
 *
 * <p>{@link #joinSubsets} never touches the result builder, so no row pairs are emitted for the
 * groups it is given.
 */
public class NoOpStrategy implements PluggableJoinStrategy {
  /**
   * Ignores both groups and leaves {@code resultBuilder} and {@code problemAggregator} untouched.
   *
   * @param leftGroup row indices of the left-hand group (unused)
   * @param rightGroup row indices of the right-hand group (unused)
   * @param resultBuilder the builder collecting join results (never modified here)
   * @param problemAggregator the aggregator for reported problems (never used here)
   */
  @Override
  public void joinSubsets(
      List<Integer> leftGroup,
      List<Integer> rightGroup,
      JoinResult.Builder resultBuilder,
      ProblemAggregator problemAggregator) {
    // Intentionally empty: this strategy contributes no rows to the result.
  }
}

View File

@ -7,6 +7,7 @@ import java.util.List;
import org.enso.base.ObjectComparator;
import org.enso.table.data.column.storage.Storage;
import org.enso.table.data.index.OrderedMultiValueKey;
import org.enso.table.data.table.join.JoinKind;
import org.enso.table.data.table.join.JoinResult;
import org.enso.table.data.table.join.JoinStrategy;
import org.enso.table.data.table.join.PluggableJoinStrategy;
@ -16,9 +17,9 @@ import org.graalvm.polyglot.Context;
public class SortJoin implements JoinStrategy, PluggableJoinStrategy {
public SortJoin(List<Between> conditions, JoinResult.BuilderSettings resultBuilderSettings) {
public SortJoin(List<Between> conditions, JoinKind joinKind) {
JoinStrategy.ensureConditionsNotEmpty(conditions);
this.resultBuilderSettings = resultBuilderSettings;
this.joinKind = joinKind;
Context context = Context.getCurrent();
int nConditions = conditions.size();
@ -35,7 +36,7 @@ public class SortJoin implements JoinStrategy, PluggableJoinStrategy {
}
}
private final JoinResult.BuilderSettings resultBuilderSettings;
private final JoinKind joinKind;
private final int[] directions;
private final Storage<?>[] leftStorages;
@ -46,13 +47,13 @@ public class SortJoin implements JoinStrategy, PluggableJoinStrategy {
@Override
public JoinResult join(ProblemAggregator problemAggregator) {
Context context = Context.getCurrent();
JoinResult.Builder resultBuilder = new JoinResult.Builder(resultBuilderSettings);
JoinResult.Builder resultBuilder = new JoinResult.Builder();
int leftRowCount = leftStorages[0].size();
int rightRowCount = lowerStorages[0].size();
if (leftRowCount == 0 || rightRowCount == 0) {
// if one group is completely empty, there will be no matches to report
return resultBuilder.build();
return resultBuilder.buildAndInvalidate();
}
List<OrderedMultiValueKey> leftKeys = new ArrayList<>(leftRowCount);
for (int i = 0; i < leftRowCount; i++) {
@ -64,13 +65,13 @@ public class SortJoin implements JoinStrategy, PluggableJoinStrategy {
for (int rightRowIx = 0; rightRowIx < rightRowCount; rightRowIx++) {
int matches = addMatchingLeftRows(leftIndex, rightRowIx, resultBuilder);
if (resultBuilderSettings.wantsRightUnmatched() && matches == 0) {
if (joinKind.wantsRightUnmatched && matches == 0) {
resultBuilder.addUnmatchedRightRow(rightRowIx);
}
context.safepoint();
}
if (resultBuilderSettings.wantsLeftUnmatched()) {
if (joinKind.wantsLeftUnmatched) {
for (int leftRowIx = 0; leftRowIx < leftRowCount; leftRowIx++) {
if (!matchedLeftRows.get(leftRowIx)) {
resultBuilder.addUnmatchedLeftRow(leftRowIx);
@ -79,7 +80,7 @@ public class SortJoin implements JoinStrategy, PluggableJoinStrategy {
}
}
return resultBuilder.build();
return resultBuilder.buildAndInvalidate();
}
@Override
@ -103,13 +104,13 @@ public class SortJoin implements JoinStrategy, PluggableJoinStrategy {
for (int rightRowIx : rightGroup) {
int matches = addMatchingLeftRows(leftIndex, rightRowIx, resultBuilder);
if (resultBuilderSettings.wantsRightUnmatched() && matches == 0) {
if (joinKind.wantsRightUnmatched && matches == 0) {
resultBuilder.addUnmatchedRightRow(rightRowIx);
}
context.safepoint();
}
if (resultBuilderSettings.wantsLeftUnmatched()) {
if (joinKind.wantsLeftUnmatched) {
for (int leftRowIx : leftGroup) {
if (!matchedLeftRows.get(leftRowIx)) {
resultBuilder.addUnmatchedLeftRow(leftRowIx);
@ -161,10 +162,10 @@ public class SortJoin implements JoinStrategy, PluggableJoinStrategy {
if (isInRange(key, lowerBound, upperBound)) {
int leftRowIx = key.getRowIndex();
matchCount++;
if (resultBuilderSettings.wantsCommon()) {
if (joinKind.wantsCommon) {
resultBuilder.addMatchedRowsPair(leftRowIx, rightRowIx);
}
if (resultBuilderSettings.wantsLeftUnmatched()) {
if (joinKind.wantsLeftUnmatched) {
matchedLeftRows.set(leftRowIx);
}
}

View File

@ -5,6 +5,7 @@ import org.enso.base.text.TextFoldingStrategy;
import org.enso.table.data.index.MultiValueIndex;
import org.enso.table.data.index.UnorderedMultiValueKey;
import org.enso.table.data.table.Column;
import org.enso.table.data.table.join.JoinKind;
import org.enso.table.data.table.join.JoinResult;
import org.enso.table.data.table.join.JoinStrategy;
import org.enso.table.data.table.join.PluggableJoinStrategy;
@ -24,10 +25,10 @@ public class HashJoin implements JoinStrategy {
public HashJoin(
List<HashableCondition> conditions,
PluggableJoinStrategy remainingMatcher,
JoinResult.BuilderSettings resultBuilderSettings) {
JoinKind joinKind) {
JoinStrategy.ensureConditionsNotEmpty(conditions);
this.remainingMatcher = remainingMatcher;
this.resultBuilderSettings = resultBuilderSettings;
this.joinKind = joinKind;
List<HashEqualityCondition> equalConditions =
conditions.stream().map(HashJoin::makeHashEqualityCondition).toList();
@ -46,7 +47,7 @@ public class HashJoin implements JoinStrategy {
private final Column[] leftEquals, rightEquals;
private final List<TextFoldingStrategy> textFoldingStrategies;
private final PluggableJoinStrategy remainingMatcher;
private final JoinResult.BuilderSettings resultBuilderSettings;
private final JoinKind joinKind;
@Override
public JoinResult join(ProblemAggregator problemAggregator) {
@ -59,7 +60,7 @@ public class HashJoin implements JoinStrategy {
MultiValueIndex.makeUnorderedIndex(
rightEquals, rightEquals[0].getSize(), textFoldingStrategies, problemAggregator);
JoinResult.Builder resultBuilder = new JoinResult.Builder(resultBuilderSettings);
JoinResult.Builder resultBuilder = new JoinResult.Builder();
for (var leftEntry : leftIndex.mapping().entrySet()) {
UnorderedMultiValueKey leftKey = leftEntry.getKey();
List<Integer> leftRows = leftEntry.getValue();
@ -68,7 +69,7 @@ public class HashJoin implements JoinStrategy {
if (rightRows != null) {
remainingMatcher.joinSubsets(leftRows, rightRows, resultBuilder, problemAggregator);
} else {
if (resultBuilderSettings.wantsLeftUnmatched()) {
if (joinKind.wantsLeftUnmatched) {
for (int leftRow : leftRows) {
resultBuilder.addUnmatchedLeftRow(leftRow);
context.safepoint();
@ -79,7 +80,7 @@ public class HashJoin implements JoinStrategy {
context.safepoint();
}
if (resultBuilderSettings.wantsRightUnmatched()) {
if (joinKind.wantsRightUnmatched) {
for (var rightEntry : rightIndex.mapping().entrySet()) {
UnorderedMultiValueKey rightKey = rightEntry.getKey();
boolean wasCompletelyUnmatched = !leftIndex.contains(rightKey);
@ -91,7 +92,7 @@ public class HashJoin implements JoinStrategy {
}
}
return resultBuilder.build();
return resultBuilder.buildAndInvalidate();
}
private static HashEqualityCondition makeHashEqualityCondition(HashableCondition eq) {

View File

@ -201,7 +201,7 @@ public class LookupJoin {
@Override
public Column build(int[] orderMask) {
assert orderMask != null;
return lookupColumn.applyMask(new OrderMask(orderMask));
return lookupColumn.applyMask(OrderMask.fromArray(orderMask));
}
}
}

View File

@ -71,17 +71,6 @@ public class OrderBuilder {
int[] positions =
IntStream.range(0, size).boxed().sorted(comparator).mapToInt(i -> i).toArray();
return new OrderMask(positions);
}
/**
* Builds an order mask based that will reverse the order of the data being masked.
*
* @param size the length of the data being masked
* @return an order mask that will result in reversing the data it is applied to
*/
public static OrderMask buildReversedMask(int size) {
int[] positions = IntStream.range(0, size).map(i -> size - i - 1).toArray();
return new OrderMask(positions);
return OrderMask.fromArray(positions);
}
}

View File

@ -153,8 +153,8 @@ create_scenario_antijoin num_rows =
## This is a scenario where we join a very large table with a much smaller table
to check an optimisation where we only index the smaller of the 2 tables
create_scenario_large_small_table =
xs = (0.up_to 1000000).map _-> Random.integer 0 99
ys = (0.up_to 100).to_vector
xs = (0.up_to 10000000).map _-> Random.integer 0 999
ys = (0.up_to 1000).to_vector
table1 = Table.new [["key", xs]]
table2 = Table.new [["key", ys]]
Scenario.Value table1 table2
@ -228,12 +228,12 @@ collect_benches = Bench.build builder->
r = scenario.table2.join scenario.table1 on="key" join_kind=Join_Kind.Left_Exclusive
assert (r.row_count == 1000)
if extended_tests then group_builder.specify "Join_Large_Table_to_Small_Table" <|
group_builder.specify "Join_Large_Table_to_Small_Table" <|
scenario = data.large_small_table
r = scenario.table1.join scenario.table2 on="key"
assert (r.row_count == scenario.table1.row_count)
if extended_tests then group_builder.specify "Join_Small_Table_to_Large_Table" <|
group_builder.specify "Join_Small_Table_to_Large_Table" <|
scenario = data.large_small_table
r = scenario.table2.join scenario.table1 on="key"
assert (r.row_count == scenario.table1.row_count)