mirror of
https://github.com/enso-org/enso.git
synced 2024-12-23 12:42:16 +03:00
Refactor OrderMask to avoid memory copying (#8863)
The goal of this PR is to refactor the design of OrderMask and avoid copying arrays or lists wherever possible. We have removed a few legacy functions which were not being used. On a poor man's benchmark it seems to be quicker (13s vs 16s), and memory usage should be lower.
This commit is contained in:
parent
5dd2dc1c93
commit
0b6db5797c
@ -32,6 +32,7 @@ from project.Internal.Column_Format import all
|
||||
from project.Internal.Java_Exports import make_date_builder_adapter, make_string_builder
|
||||
|
||||
polyglot java import org.enso.base.Time_Utils
|
||||
polyglot java import org.enso.table.data.mask.OrderMask
|
||||
polyglot java import org.enso.table.data.column.operation.cast.CastProblemAggregator
|
||||
polyglot java import org.enso.table.data.column.storage.Storage as Java_Storage
|
||||
polyglot java import org.enso.table.data.table.Column as Java_Column
|
||||
@ -2284,7 +2285,7 @@ type Column
|
||||
example_reverse = Examples.integer_column.reverse
|
||||
reverse : Column
|
||||
reverse self =
|
||||
mask = OrderBuilder.buildReversedMask self.length
|
||||
mask = OrderMask.reverse self.length
|
||||
Column.Value (self.java_column.applyMask mask)
|
||||
|
||||
## GROUP Standard.Base.Metadata
|
||||
|
@ -78,7 +78,6 @@ polyglot java import org.enso.table.error.NonUniqueLookupKey
|
||||
polyglot java import org.enso.table.error.NullValuesInKeyColumns
|
||||
polyglot java import org.enso.table.error.TooManyColumnsException
|
||||
polyglot java import org.enso.table.error.UnmatchedRow
|
||||
polyglot java import org.enso.table.operations.OrderBuilder
|
||||
polyglot java import org.enso.table.parsing.problems.ParseProblemAggregator
|
||||
|
||||
## Represents a column-oriented table data structure.
|
||||
@ -2423,7 +2422,7 @@ type Table
|
||||
example_reverse = Examples.inventory_table.reverse
|
||||
reverse : Table
|
||||
reverse self =
|
||||
mask = OrderBuilder.buildReversedMask self.row_count
|
||||
mask = OrderMask.reverse self.row_count
|
||||
Table.Value <| self.java_table.applyMask mask
|
||||
|
||||
## GROUP Standard.Base.Output
|
||||
|
@ -12,7 +12,6 @@ polyglot java import java.lang.IllegalArgumentException
|
||||
polyglot java import java.time.temporal.UnsupportedTemporalTypeException
|
||||
polyglot java import org.enso.table.data.column.storage.Storage as Java_Storage
|
||||
polyglot java import org.enso.table.data.table.Column as Java_Column
|
||||
polyglot java import org.enso.table.operations.OrderBuilder
|
||||
|
||||
## PRIVATE
|
||||
Create a formatter for the specified `Value_Type`.
|
||||
|
@ -137,7 +137,7 @@ fan_out_to_rows_and_columns table input_column_id function column_names at_least
|
||||
Column.from_storage column_name output_storage
|
||||
|
||||
# Build the order mask.
|
||||
order_mask = OrderMask.new (order_mask_positions.to_vector)
|
||||
order_mask = OrderMask.fromArray (order_mask_positions.to_vector)
|
||||
|
||||
## Build the new table, replacing the input column with the new output
|
||||
columns.
|
||||
|
@ -4,8 +4,6 @@ import project.Data.Table.Table
|
||||
import project.Data.Type.Value_Type.Value_Type
|
||||
from project.Internal.Fan_Out import all
|
||||
|
||||
polyglot java import org.enso.table.data.mask.OrderMask
|
||||
|
||||
## PRIVATE
|
||||
Splits a column of text into a set of new columns.
|
||||
See `Table.split_to_columns`.
|
||||
|
@ -57,15 +57,9 @@ public class IntArrayBuilder {
|
||||
* <p>After calling this method, the builder is invalidated and cannot be used anymore. Any usage
|
||||
* of the builder afterwards will result in a {@code NullPointerException}.
|
||||
*/
|
||||
public int[] unsafeGetStorageAndInvalidateTheBuilder() {
|
||||
public int[] unsafeGetResultAndInvalidate() {
|
||||
int[] tmp = storage;
|
||||
this.storage = null;
|
||||
return tmp;
|
||||
}
|
||||
|
||||
public int[] build() {
|
||||
int[] result = new int[length];
|
||||
System.arraycopy(storage, 0, result, 0, length);
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
@ -158,9 +158,9 @@ public class CaseFoldedString {
|
||||
|
||||
return new CaseFoldedString(
|
||||
stringBuilder.toString(),
|
||||
grapheme_mapping.unsafeGetStorageAndInvalidateTheBuilder(),
|
||||
codeunit_start_mapping.unsafeGetStorageAndInvalidateTheBuilder(),
|
||||
codeunit_end_mapping.unsafeGetStorageAndInvalidateTheBuilder());
|
||||
grapheme_mapping.unsafeGetResultAndInvalidate(),
|
||||
codeunit_start_mapping.unsafeGetResultAndInvalidate(),
|
||||
codeunit_end_mapping.unsafeGetResultAndInvalidate());
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -203,19 +203,19 @@ public final class BoolStorage extends Storage<Boolean> {
|
||||
@Override
|
||||
public BoolStorage applyMask(OrderMask mask) {
|
||||
Context context = Context.getCurrent();
|
||||
int[] positions = mask.getPositions();
|
||||
BitSet newNa = new BitSet();
|
||||
BitSet newVals = new BitSet();
|
||||
for (int i = 0; i < positions.length; i++) {
|
||||
if (positions[i] == Index.NOT_FOUND || isMissing.get(positions[i])) {
|
||||
for (int i = 0; i < mask.length(); i++) {
|
||||
int position = mask.get(i);
|
||||
if (position == Index.NOT_FOUND || isMissing.get(position)) {
|
||||
newNa.set(i);
|
||||
} else if (values.get(positions[i])) {
|
||||
} else if (values.get(position)) {
|
||||
newVals.set(i);
|
||||
}
|
||||
|
||||
context.safepoint();
|
||||
}
|
||||
return new BoolStorage(newVals, newNa, positions.length, negated);
|
||||
return new BoolStorage(newVals, newNa, mask.length(), negated);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -127,18 +127,13 @@ public abstract class SpecializedStorage<T> extends Storage<T> {
|
||||
@Override
|
||||
public SpecializedStorage<T> applyMask(OrderMask mask) {
|
||||
Context context = Context.getCurrent();
|
||||
int[] positions = mask.getPositions();
|
||||
T[] newData = newUnderlyingArray(positions.length);
|
||||
for (int i = 0; i < positions.length; i++) {
|
||||
if (positions[i] == Index.NOT_FOUND) {
|
||||
newData[i] = null;
|
||||
} else {
|
||||
newData[i] = data[positions[i]];
|
||||
}
|
||||
|
||||
T[] newData = newUnderlyingArray(mask.length());
|
||||
for (int i = 0; i < mask.length(); i++) {
|
||||
int position = mask.get(i);
|
||||
newData[i] = position == Index.NOT_FOUND ? null : data[position];
|
||||
context.safepoint();
|
||||
}
|
||||
return newInstance(newData, positions.length);
|
||||
return newInstance(newData, newData.length);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -80,20 +80,20 @@ public abstract class ComputedLongStorage extends AbstractLongStorage {
|
||||
|
||||
@Override
|
||||
public Storage<Long> applyMask(OrderMask mask) {
|
||||
int[] positions = mask.getPositions();
|
||||
long[] newData = new long[positions.length];
|
||||
long[] newData = new long[mask.length()];
|
||||
BitSet newMissing = new BitSet();
|
||||
Context context = Context.getCurrent();
|
||||
for (int i = 0; i < positions.length; i++) {
|
||||
if (positions[i] == Index.NOT_FOUND) {
|
||||
for (int i = 0; i < mask.length(); i++) {
|
||||
int position = mask.get(i);
|
||||
if (position == Index.NOT_FOUND) {
|
||||
newMissing.set(i);
|
||||
} else {
|
||||
newData[i] = getItem(positions[i]);
|
||||
newData[i] = getItem(position);
|
||||
}
|
||||
|
||||
context.safepoint();
|
||||
}
|
||||
return new LongStorage(newData, positions.length, newMissing, getType());
|
||||
return new LongStorage(newData, newData.length, newMissing, getType());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -100,15 +100,15 @@ public abstract class ComputedNullableLongStorage extends AbstractLongStorage {
|
||||
|
||||
@Override
|
||||
public Storage<Long> applyMask(OrderMask mask) {
|
||||
int[] positions = mask.getPositions();
|
||||
long[] newData = new long[positions.length];
|
||||
long[] newData = new long[mask.length()];
|
||||
BitSet newMissing = new BitSet();
|
||||
Context context = Context.getCurrent();
|
||||
for (int i = 0; i < positions.length; i++) {
|
||||
if (positions[i] == Index.NOT_FOUND) {
|
||||
for (int i = 0; i < mask.length(); i++) {
|
||||
int position = mask.get(i);
|
||||
if (position == Index.NOT_FOUND) {
|
||||
newMissing.set(i);
|
||||
} else {
|
||||
Long item = computeItem(positions[i]);
|
||||
Long item = computeItem(position);
|
||||
if (item == null) {
|
||||
newMissing.set(i);
|
||||
} else {
|
||||
@ -118,7 +118,7 @@ public abstract class ComputedNullableLongStorage extends AbstractLongStorage {
|
||||
|
||||
context.safepoint();
|
||||
}
|
||||
return new LongStorage(newData, positions.length, newMissing, getType());
|
||||
return new LongStorage(newData, newData.length, newMissing, getType());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -278,20 +278,20 @@ public final class DoubleStorage extends NumericStorage<Double> implements Doubl
|
||||
|
||||
@Override
|
||||
public Storage<Double> applyMask(OrderMask mask) {
|
||||
int[] positions = mask.getPositions();
|
||||
long[] newData = new long[positions.length];
|
||||
long[] newData = new long[mask.length()];
|
||||
BitSet newMissing = new BitSet();
|
||||
Context context = Context.getCurrent();
|
||||
for (int i = 0; i < positions.length; i++) {
|
||||
if (positions[i] == Index.NOT_FOUND || isMissing.get(positions[i])) {
|
||||
for (int i = 0; i < mask.length(); i++) {
|
||||
int position = mask.get(i);
|
||||
if (position == Index.NOT_FOUND || isMissing.get(position)) {
|
||||
newMissing.set(i);
|
||||
} else {
|
||||
newData[i] = data[positions[i]];
|
||||
newData[i] = data[position];
|
||||
}
|
||||
|
||||
context.safepoint();
|
||||
}
|
||||
return new DoubleStorage(newData, positions.length, newMissing);
|
||||
return new DoubleStorage(newData, newData.length, newMissing);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -191,20 +191,20 @@ public final class LongStorage extends AbstractLongStorage {
|
||||
|
||||
@Override
|
||||
public Storage<Long> applyMask(OrderMask mask) {
|
||||
int[] positions = mask.getPositions();
|
||||
long[] newData = new long[positions.length];
|
||||
long[] newData = new long[mask.length()];
|
||||
BitSet newMissing = new BitSet();
|
||||
Context context = Context.getCurrent();
|
||||
for (int i = 0; i < positions.length; i++) {
|
||||
if (positions[i] == Index.NOT_FOUND || isMissing.get(positions[i])) {
|
||||
for (int i = 0; i < mask.length(); i++) {
|
||||
int position = mask.get(i);
|
||||
if (position == Index.NOT_FOUND || isMissing.get(position)) {
|
||||
newMissing.set(i);
|
||||
} else {
|
||||
newData[i] = data[positions[i]];
|
||||
newData[i] = data[position];
|
||||
}
|
||||
|
||||
context.safepoint();
|
||||
}
|
||||
return new LongStorage(newData, positions.length, newMissing, type);
|
||||
return new LongStorage(newData, newData.length, newMissing, type);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -1,69 +1,96 @@
|
||||
package org.enso.table.data.mask;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import org.graalvm.polyglot.Context;
|
||||
import java.util.function.ToIntFunction;
|
||||
|
||||
/** Describes a storage reordering operator. */
|
||||
public class OrderMask {
|
||||
private final int[] positions;
|
||||
public interface OrderMask {
|
||||
int length();
|
||||
|
||||
/**
|
||||
* Creates a new reordering operator, with the specified characteristics. See {@link
|
||||
* #getPositions()} for a description of the semantics.
|
||||
*
|
||||
* @param positions the positions array, as described by {@link #getPositions()}
|
||||
*/
|
||||
public OrderMask(int[] positions) {
|
||||
this.positions = positions;
|
||||
}
|
||||
|
||||
/**
|
||||
* Describes the reordering that should happen on the applying storage.
|
||||
* Describes the reordering that should happen on the applying storage at the index.
|
||||
*
|
||||
* <p>The resulting storage should contain the {@code positions[i]}-th element of the original
|
||||
* storage at the i-th position. {@code positions[i]} may be equal to {@link
|
||||
* storage at the {@code idx}-th position. It may return {@link
|
||||
* org.enso.table.data.index.Index#NOT_FOUND}, in which case a missing value should be inserted at
|
||||
* this position.
|
||||
*/
|
||||
public int[] getPositions() {
|
||||
return positions;
|
||||
int get(int idx);
|
||||
|
||||
static OrderMask empty() {
|
||||
return new OrderMaskFromArray(new int[0], 0);
|
||||
}
|
||||
|
||||
public OrderMask append(OrderMask other) {
|
||||
int[] result = Arrays.copyOf(positions, positions.length + other.positions.length);
|
||||
System.arraycopy(other.positions, 0, result, positions.length, other.positions.length);
|
||||
return new OrderMask(result);
|
||||
static OrderMask reverse(int size) {
|
||||
return new OrderMaskReversed(size);
|
||||
}
|
||||
|
||||
public static OrderMask empty() {
|
||||
return new OrderMask(new int[0]);
|
||||
static OrderMask fromArray(int[] positions) {
|
||||
return fromArray(positions, positions.length);
|
||||
}
|
||||
|
||||
public static OrderMask fromList(List<Integer> positions) {
|
||||
Context context = Context.getCurrent();
|
||||
int[] result = new int[positions.size()];
|
||||
for (int i = 0; i < positions.size(); i++) {
|
||||
result[i] = positions.get(i);
|
||||
context.safepoint();
|
||||
static OrderMask fromArray(int[] positions, int length) {
|
||||
return new OrderMaskFromArray(positions, length);
|
||||
}
|
||||
|
||||
static <T> OrderMask fromObjects(T[] input, ToIntFunction<T> function) {
|
||||
return new OrderMaskGeneric<>(input, function);
|
||||
}
|
||||
|
||||
class OrderMaskFromArray implements OrderMask {
|
||||
private final int[] positions;
|
||||
private final int length;
|
||||
|
||||
public OrderMaskFromArray(int[] positions, int length) {
|
||||
this.positions = positions;
|
||||
this.length = length;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int length() {
|
||||
return length;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int get(int idx) {
|
||||
return positions[idx];
|
||||
}
|
||||
return new OrderMask(result);
|
||||
}
|
||||
|
||||
public static OrderMask concat(List<OrderMask> masks) {
|
||||
Context context = Context.getCurrent();
|
||||
int size = 0;
|
||||
for (OrderMask mask : masks) {
|
||||
size += mask.positions.length;
|
||||
context.safepoint();
|
||||
class OrderMaskGeneric<T> implements OrderMask {
|
||||
private final T[] positions;
|
||||
private final ToIntFunction<T> function;
|
||||
|
||||
public OrderMaskGeneric(T[] positions, ToIntFunction<T> function) {
|
||||
this.positions = positions;
|
||||
this.function = function;
|
||||
}
|
||||
int[] result = new int[size];
|
||||
int offset = 0;
|
||||
for (OrderMask mask : masks) {
|
||||
System.arraycopy(mask.positions, 0, result, offset, mask.positions.length);
|
||||
offset += mask.positions.length;
|
||||
context.safepoint();
|
||||
|
||||
@Override
|
||||
public int length() {
|
||||
return positions.length;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int get(int idx) {
|
||||
return function.applyAsInt(positions[idx]);
|
||||
}
|
||||
}
|
||||
|
||||
class OrderMaskReversed implements OrderMask {
|
||||
private final int length;
|
||||
|
||||
public OrderMaskReversed(int length) {
|
||||
this.length = length;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int length() {
|
||||
return length;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int get(int idx) {
|
||||
return length - idx - 1;
|
||||
}
|
||||
return new OrderMask(result);
|
||||
}
|
||||
}
|
||||
|
@ -222,8 +222,7 @@ public class Table {
|
||||
context.safepoint();
|
||||
}
|
||||
Arrays.sort(keys);
|
||||
int[] positions = Arrays.stream(keys).mapToInt(MultiValueKeyBase::getRowIndex).toArray();
|
||||
OrderMask mask = new OrderMask(positions);
|
||||
OrderMask mask = OrderMask.fromObjects(keys, MultiValueKeyBase::getRowIndex);
|
||||
return this.applyMask(mask);
|
||||
}
|
||||
|
||||
|
@ -5,16 +5,13 @@ import org.graalvm.polyglot.Context;
|
||||
public class CrossJoin {
|
||||
public static JoinResult perform(int leftRowCount, int rightRowCount) {
|
||||
Context context = Context.getCurrent();
|
||||
JoinResult.BuilderSettings settings = new JoinResult.BuilderSettings(true, true, true);
|
||||
JoinResult.Builder resultBuilder =
|
||||
new JoinResult.Builder(leftRowCount * rightRowCount, settings);
|
||||
JoinResult.Builder resultBuilder = new JoinResult.Builder(leftRowCount * rightRowCount);
|
||||
for (int l = 0; l < leftRowCount; ++l) {
|
||||
for (int r = 0; r < rightRowCount; ++r) {
|
||||
resultBuilder.addMatchedRowsPair(l, r);
|
||||
context.safepoint();
|
||||
}
|
||||
}
|
||||
|
||||
return resultBuilder.build();
|
||||
return resultBuilder.buildAndInvalidate();
|
||||
}
|
||||
}
|
||||
|
@ -1,21 +1,20 @@
|
||||
package org.enso.table.data.table.join;
|
||||
|
||||
public enum JoinKind {
|
||||
INNER,
|
||||
FULL,
|
||||
LEFT_OUTER,
|
||||
RIGHT_OUTER,
|
||||
LEFT_ANTI,
|
||||
RIGHT_ANTI;
|
||||
INNER(true, false, false),
|
||||
FULL(true, true, true),
|
||||
LEFT_OUTER(true, true, false),
|
||||
RIGHT_OUTER(true, false, true),
|
||||
LEFT_ANTI(false, true, false),
|
||||
RIGHT_ANTI(false, false, true);
|
||||
|
||||
public static JoinResult.BuilderSettings makeSettings(JoinKind joinKind) {
|
||||
return switch (joinKind) {
|
||||
case INNER -> new JoinResult.BuilderSettings(true, false, false);
|
||||
case FULL -> new JoinResult.BuilderSettings(true, true, true);
|
||||
case LEFT_OUTER -> new JoinResult.BuilderSettings(true, true, false);
|
||||
case RIGHT_OUTER -> new JoinResult.BuilderSettings(true, false, true);
|
||||
case LEFT_ANTI -> new JoinResult.BuilderSettings(false, true, false);
|
||||
case RIGHT_ANTI -> new JoinResult.BuilderSettings(false, false, true);
|
||||
};
|
||||
public final boolean wantsCommon;
|
||||
public final boolean wantsLeftUnmatched;
|
||||
public final boolean wantsRightUnmatched;
|
||||
|
||||
private JoinKind(boolean wantsCommon, boolean wantsLeftUnmatched, boolean wantsRightUnmatched) {
|
||||
this.wantsCommon = wantsCommon;
|
||||
this.wantsLeftUnmatched = wantsLeftUnmatched;
|
||||
this.wantsRightUnmatched = wantsRightUnmatched;
|
||||
}
|
||||
}
|
||||
|
@ -3,33 +3,39 @@ package org.enso.table.data.table.join;
|
||||
import org.enso.base.arrays.IntArrayBuilder;
|
||||
import org.enso.table.data.mask.OrderMask;
|
||||
|
||||
public record JoinResult(int[] matchedRowsLeftIndices, int[] matchedRowsRightIndices) {
|
||||
public class JoinResult {
|
||||
private final int length;
|
||||
private final int[] leftIndices;
|
||||
private final int[] rightIndices;
|
||||
|
||||
public JoinResult(int[] leftIndices, int[] rightIndices, int length) {
|
||||
this.length = length;
|
||||
this.leftIndices = leftIndices;
|
||||
this.rightIndices = rightIndices;
|
||||
}
|
||||
|
||||
/** Represents a pair of indices of matched rows. -1 means an unmatched row. */
|
||||
public record RowPair(int leftIndex, int rightIndex) {}
|
||||
|
||||
public OrderMask getLeftOrderMask() {
|
||||
return new OrderMask(matchedRowsLeftIndices);
|
||||
return OrderMask.fromArray(leftIndices, length);
|
||||
}
|
||||
|
||||
public OrderMask getRightOrderMask() {
|
||||
return new OrderMask(matchedRowsRightIndices);
|
||||
return OrderMask.fromArray(rightIndices, length);
|
||||
}
|
||||
|
||||
public record BuilderSettings(
|
||||
boolean wantsCommon, boolean wantsLeftUnmatched, boolean wantsRightUnmatched) {}
|
||||
|
||||
public static class Builder {
|
||||
IntArrayBuilder leftIndices;
|
||||
IntArrayBuilder rightIndices;
|
||||
|
||||
final BuilderSettings settings;
|
||||
|
||||
public Builder(int initialCapacity, BuilderSettings settings) {
|
||||
public Builder(int initialCapacity) {
|
||||
leftIndices = new IntArrayBuilder(initialCapacity);
|
||||
rightIndices = new IntArrayBuilder(initialCapacity);
|
||||
this.settings = settings;
|
||||
}
|
||||
|
||||
public Builder(BuilderSettings settings) {
|
||||
this(128, settings);
|
||||
public Builder() {
|
||||
this(128);
|
||||
}
|
||||
|
||||
public void addMatchedRowsPair(int leftIndex, int rightIndex) {
|
||||
@ -47,8 +53,22 @@ public record JoinResult(int[] matchedRowsLeftIndices, int[] matchedRowsRightInd
|
||||
rightIndices.add(rightIndex);
|
||||
}
|
||||
|
||||
public JoinResult build() {
|
||||
return new JoinResult(leftIndices.build(), rightIndices.build());
|
||||
/**
|
||||
* Returns the result of the builder.
|
||||
*
|
||||
* <p>This method avoids copying for performance. After calling this method, the builder is
|
||||
* invalidated and cannot be used anymore. Any usage of the builder afterwards will result in a
|
||||
* {@code NullPointerException}.
|
||||
*/
|
||||
public JoinResult buildAndInvalidate() {
|
||||
var left = leftIndices;
|
||||
var right = rightIndices;
|
||||
leftIndices = null;
|
||||
rightIndices = null;
|
||||
return new JoinResult(
|
||||
left.unsafeGetResultAndInvalidate(),
|
||||
right.unsafeGetResultAndInvalidate(),
|
||||
left.getLength());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -15,8 +15,6 @@ public interface JoinStrategy {
|
||||
static JoinStrategy createStrategy(List<JoinCondition> conditions, JoinKind joinKind) {
|
||||
ensureConditionsNotEmpty(conditions);
|
||||
|
||||
JoinResult.BuilderSettings builderSettings = JoinKind.makeSettings(joinKind);
|
||||
|
||||
List<HashableCondition> hashableConditions =
|
||||
conditions.stream()
|
||||
.filter(c -> c instanceof HashableCondition)
|
||||
@ -31,12 +29,14 @@ public interface JoinStrategy {
|
||||
|
||||
if (hashableConditions.isEmpty()) {
|
||||
assert !betweenConditions.isEmpty();
|
||||
return new SortJoin(betweenConditions, builderSettings);
|
||||
return new SortJoin(betweenConditions, joinKind);
|
||||
} else if (betweenConditions.isEmpty()) {
|
||||
return new HashJoin(hashableConditions, new MatchAllStrategy(), builderSettings);
|
||||
} else {
|
||||
return new HashJoin(
|
||||
hashableConditions, new SortJoin(betweenConditions, builderSettings), builderSettings);
|
||||
hashableConditions,
|
||||
joinKind.wantsCommon ? new MatchAllStrategy() : new NoOpStrategy(),
|
||||
joinKind);
|
||||
} else {
|
||||
return new HashJoin(hashableConditions, new SortJoin(betweenConditions, joinKind), joinKind);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -15,10 +15,6 @@ public class MatchAllStrategy implements PluggableJoinStrategy {
|
||||
List<Integer> rightGroup,
|
||||
JoinResult.Builder resultBuilder,
|
||||
ProblemAggregator problemAggregator) {
|
||||
if (!resultBuilder.settings.wantsCommon()) {
|
||||
return;
|
||||
}
|
||||
|
||||
Context context = Context.getCurrent();
|
||||
for (var leftRow : leftGroup) {
|
||||
for (var rightRow : rightGroup) {
|
||||
|
@ -0,0 +1,15 @@
|
||||
package org.enso.table.data.table.join;
|
||||
|
||||
import java.util.List;
|
||||
import org.enso.table.problems.ProblemAggregator;
|
||||
|
||||
public class NoOpStrategy implements PluggableJoinStrategy {
|
||||
@Override
|
||||
public void joinSubsets(
|
||||
List<Integer> leftGroup,
|
||||
List<Integer> rightGroup,
|
||||
JoinResult.Builder resultBuilder,
|
||||
ProblemAggregator problemAggregator) {
|
||||
return;
|
||||
}
|
||||
}
|
@ -7,6 +7,7 @@ import java.util.List;
|
||||
import org.enso.base.ObjectComparator;
|
||||
import org.enso.table.data.column.storage.Storage;
|
||||
import org.enso.table.data.index.OrderedMultiValueKey;
|
||||
import org.enso.table.data.table.join.JoinKind;
|
||||
import org.enso.table.data.table.join.JoinResult;
|
||||
import org.enso.table.data.table.join.JoinStrategy;
|
||||
import org.enso.table.data.table.join.PluggableJoinStrategy;
|
||||
@ -16,9 +17,9 @@ import org.graalvm.polyglot.Context;
|
||||
|
||||
public class SortJoin implements JoinStrategy, PluggableJoinStrategy {
|
||||
|
||||
public SortJoin(List<Between> conditions, JoinResult.BuilderSettings resultBuilderSettings) {
|
||||
public SortJoin(List<Between> conditions, JoinKind joinKind) {
|
||||
JoinStrategy.ensureConditionsNotEmpty(conditions);
|
||||
this.resultBuilderSettings = resultBuilderSettings;
|
||||
this.joinKind = joinKind;
|
||||
|
||||
Context context = Context.getCurrent();
|
||||
int nConditions = conditions.size();
|
||||
@ -35,7 +36,7 @@ public class SortJoin implements JoinStrategy, PluggableJoinStrategy {
|
||||
}
|
||||
}
|
||||
|
||||
private final JoinResult.BuilderSettings resultBuilderSettings;
|
||||
private final JoinKind joinKind;
|
||||
|
||||
private final int[] directions;
|
||||
private final Storage<?>[] leftStorages;
|
||||
@ -46,13 +47,13 @@ public class SortJoin implements JoinStrategy, PluggableJoinStrategy {
|
||||
@Override
|
||||
public JoinResult join(ProblemAggregator problemAggregator) {
|
||||
Context context = Context.getCurrent();
|
||||
JoinResult.Builder resultBuilder = new JoinResult.Builder(resultBuilderSettings);
|
||||
JoinResult.Builder resultBuilder = new JoinResult.Builder();
|
||||
|
||||
int leftRowCount = leftStorages[0].size();
|
||||
int rightRowCount = lowerStorages[0].size();
|
||||
if (leftRowCount == 0 || rightRowCount == 0) {
|
||||
// if one group is completely empty, there will be no matches to report
|
||||
return resultBuilder.build();
|
||||
return resultBuilder.buildAndInvalidate();
|
||||
}
|
||||
List<OrderedMultiValueKey> leftKeys = new ArrayList<>(leftRowCount);
|
||||
for (int i = 0; i < leftRowCount; i++) {
|
||||
@ -64,13 +65,13 @@ public class SortJoin implements JoinStrategy, PluggableJoinStrategy {
|
||||
|
||||
for (int rightRowIx = 0; rightRowIx < rightRowCount; rightRowIx++) {
|
||||
int matches = addMatchingLeftRows(leftIndex, rightRowIx, resultBuilder);
|
||||
if (resultBuilderSettings.wantsRightUnmatched() && matches == 0) {
|
||||
if (joinKind.wantsRightUnmatched && matches == 0) {
|
||||
resultBuilder.addUnmatchedRightRow(rightRowIx);
|
||||
}
|
||||
context.safepoint();
|
||||
}
|
||||
|
||||
if (resultBuilderSettings.wantsLeftUnmatched()) {
|
||||
if (joinKind.wantsLeftUnmatched) {
|
||||
for (int leftRowIx = 0; leftRowIx < leftRowCount; leftRowIx++) {
|
||||
if (!matchedLeftRows.get(leftRowIx)) {
|
||||
resultBuilder.addUnmatchedLeftRow(leftRowIx);
|
||||
@ -79,7 +80,7 @@ public class SortJoin implements JoinStrategy, PluggableJoinStrategy {
|
||||
}
|
||||
}
|
||||
|
||||
return resultBuilder.build();
|
||||
return resultBuilder.buildAndInvalidate();
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -103,13 +104,13 @@ public class SortJoin implements JoinStrategy, PluggableJoinStrategy {
|
||||
|
||||
for (int rightRowIx : rightGroup) {
|
||||
int matches = addMatchingLeftRows(leftIndex, rightRowIx, resultBuilder);
|
||||
if (resultBuilderSettings.wantsRightUnmatched() && matches == 0) {
|
||||
if (joinKind.wantsRightUnmatched && matches == 0) {
|
||||
resultBuilder.addUnmatchedRightRow(rightRowIx);
|
||||
}
|
||||
context.safepoint();
|
||||
}
|
||||
|
||||
if (resultBuilderSettings.wantsLeftUnmatched()) {
|
||||
if (joinKind.wantsLeftUnmatched) {
|
||||
for (int leftRowIx : leftGroup) {
|
||||
if (!matchedLeftRows.get(leftRowIx)) {
|
||||
resultBuilder.addUnmatchedLeftRow(leftRowIx);
|
||||
@ -161,10 +162,10 @@ public class SortJoin implements JoinStrategy, PluggableJoinStrategy {
|
||||
if (isInRange(key, lowerBound, upperBound)) {
|
||||
int leftRowIx = key.getRowIndex();
|
||||
matchCount++;
|
||||
if (resultBuilderSettings.wantsCommon()) {
|
||||
if (joinKind.wantsCommon) {
|
||||
resultBuilder.addMatchedRowsPair(leftRowIx, rightRowIx);
|
||||
}
|
||||
if (resultBuilderSettings.wantsLeftUnmatched()) {
|
||||
if (joinKind.wantsLeftUnmatched) {
|
||||
matchedLeftRows.set(leftRowIx);
|
||||
}
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ import org.enso.base.text.TextFoldingStrategy;
|
||||
import org.enso.table.data.index.MultiValueIndex;
|
||||
import org.enso.table.data.index.UnorderedMultiValueKey;
|
||||
import org.enso.table.data.table.Column;
|
||||
import org.enso.table.data.table.join.JoinKind;
|
||||
import org.enso.table.data.table.join.JoinResult;
|
||||
import org.enso.table.data.table.join.JoinStrategy;
|
||||
import org.enso.table.data.table.join.PluggableJoinStrategy;
|
||||
@ -24,10 +25,10 @@ public class HashJoin implements JoinStrategy {
|
||||
public HashJoin(
|
||||
List<HashableCondition> conditions,
|
||||
PluggableJoinStrategy remainingMatcher,
|
||||
JoinResult.BuilderSettings resultBuilderSettings) {
|
||||
JoinKind joinKind) {
|
||||
JoinStrategy.ensureConditionsNotEmpty(conditions);
|
||||
this.remainingMatcher = remainingMatcher;
|
||||
this.resultBuilderSettings = resultBuilderSettings;
|
||||
this.joinKind = joinKind;
|
||||
|
||||
List<HashEqualityCondition> equalConditions =
|
||||
conditions.stream().map(HashJoin::makeHashEqualityCondition).toList();
|
||||
@ -46,7 +47,7 @@ public class HashJoin implements JoinStrategy {
|
||||
private final Column[] leftEquals, rightEquals;
|
||||
private final List<TextFoldingStrategy> textFoldingStrategies;
|
||||
private final PluggableJoinStrategy remainingMatcher;
|
||||
private final JoinResult.BuilderSettings resultBuilderSettings;
|
||||
private final JoinKind joinKind;
|
||||
|
||||
@Override
|
||||
public JoinResult join(ProblemAggregator problemAggregator) {
|
||||
@ -59,7 +60,7 @@ public class HashJoin implements JoinStrategy {
|
||||
MultiValueIndex.makeUnorderedIndex(
|
||||
rightEquals, rightEquals[0].getSize(), textFoldingStrategies, problemAggregator);
|
||||
|
||||
JoinResult.Builder resultBuilder = new JoinResult.Builder(resultBuilderSettings);
|
||||
JoinResult.Builder resultBuilder = new JoinResult.Builder();
|
||||
for (var leftEntry : leftIndex.mapping().entrySet()) {
|
||||
UnorderedMultiValueKey leftKey = leftEntry.getKey();
|
||||
List<Integer> leftRows = leftEntry.getValue();
|
||||
@ -68,7 +69,7 @@ public class HashJoin implements JoinStrategy {
|
||||
if (rightRows != null) {
|
||||
remainingMatcher.joinSubsets(leftRows, rightRows, resultBuilder, problemAggregator);
|
||||
} else {
|
||||
if (resultBuilderSettings.wantsLeftUnmatched()) {
|
||||
if (joinKind.wantsLeftUnmatched) {
|
||||
for (int leftRow : leftRows) {
|
||||
resultBuilder.addUnmatchedLeftRow(leftRow);
|
||||
context.safepoint();
|
||||
@ -79,7 +80,7 @@ public class HashJoin implements JoinStrategy {
|
||||
context.safepoint();
|
||||
}
|
||||
|
||||
if (resultBuilderSettings.wantsRightUnmatched()) {
|
||||
if (joinKind.wantsRightUnmatched) {
|
||||
for (var rightEntry : rightIndex.mapping().entrySet()) {
|
||||
UnorderedMultiValueKey rightKey = rightEntry.getKey();
|
||||
boolean wasCompletelyUnmatched = !leftIndex.contains(rightKey);
|
||||
@ -91,7 +92,7 @@ public class HashJoin implements JoinStrategy {
|
||||
}
|
||||
}
|
||||
|
||||
return resultBuilder.build();
|
||||
return resultBuilder.buildAndInvalidate();
|
||||
}
|
||||
|
||||
private static HashEqualityCondition makeHashEqualityCondition(HashableCondition eq) {
|
||||
|
@ -201,7 +201,7 @@ public class LookupJoin {
|
||||
@Override
|
||||
public Column build(int[] orderMask) {
|
||||
assert orderMask != null;
|
||||
return lookupColumn.applyMask(new OrderMask(orderMask));
|
||||
return lookupColumn.applyMask(OrderMask.fromArray(orderMask));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -71,17 +71,6 @@ public class OrderBuilder {
|
||||
|
||||
int[] positions =
|
||||
IntStream.range(0, size).boxed().sorted(comparator).mapToInt(i -> i).toArray();
|
||||
return new OrderMask(positions);
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds an order mask that will reverse the order of the data being masked.
|
||||
*
|
||||
* @param size the length of the data being masked
|
||||
* @return an order mask that will result in reversing the data it is applied to
|
||||
*/
|
||||
public static OrderMask buildReversedMask(int size) {
|
||||
int[] positions = IntStream.range(0, size).map(i -> size - i - 1).toArray();
|
||||
return new OrderMask(positions);
|
||||
return OrderMask.fromArray(positions);
|
||||
}
|
||||
}
|
||||
|
@ -153,8 +153,8 @@ create_scenario_antijoin num_rows =
|
||||
## This is a scenario where we join a very large table with a much smaller table
|
||||
to check an optimisation where we only index the smaller of the 2 tables
|
||||
create_scenario_large_small_table =
|
||||
xs = (0.up_to 1000000).map _-> Random.integer 0 99
|
||||
ys = (0.up_to 100).to_vector
|
||||
xs = (0.up_to 10000000).map _-> Random.integer 0 999
|
||||
ys = (0.up_to 1000).to_vector
|
||||
table1 = Table.new [["key", xs]]
|
||||
table2 = Table.new [["key", ys]]
|
||||
Scenario.Value table1 table2
|
||||
@ -228,12 +228,12 @@ collect_benches = Bench.build builder->
|
||||
r = scenario.table2.join scenario.table1 on="key" join_kind=Join_Kind.Left_Exclusive
|
||||
assert (r.row_count == 1000)
|
||||
|
||||
if extended_tests then group_builder.specify "Join_Large_Table_to_Small_Table" <|
|
||||
group_builder.specify "Join_Large_Table_to_Small_Table" <|
|
||||
scenario = data.large_small_table
|
||||
r = scenario.table1.join scenario.table2 on="key"
|
||||
assert (r.row_count == scenario.table1.row_count)
|
||||
|
||||
if extended_tests then group_builder.specify "Join_Small_Table_to_Large_Table" <|
|
||||
group_builder.specify "Join_Small_Table_to_Large_Table" <|
|
||||
scenario = data.large_small_table
|
||||
r = scenario.table2.join scenario.table1 on="key"
|
||||
assert (r.row_count == scenario.table1.row_count)
|
||||
|
Loading…
Reference in New Issue
Block a user