Improve performance of anti-join (#8338)

- Closes #8217
This commit is contained in:
Radosław Waśko 2023-11-24 03:44:57 +01:00 committed by GitHub
parent 4464a15035
commit c6b6384fe6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 182 additions and 104 deletions

View File

@ -1,3 +1,5 @@
polyglot java import org.enso.table.data.table.join.JoinKind as Java_Join_Kind
type Join_Kind type Join_Kind
## Returns only rows where a match between the left and right table is ## Returns only rows where a match between the left and right table is
found. If one row from the left table matches multiple rows in the right found. If one row from the left table matches multiple rows in the right
@ -36,3 +38,13 @@ type Join_Kind
In this mode, unlike in others, only columns of the right table are In this mode, unlike in others, only columns of the right table are
returned, since all columns of the left table would be all null anyway. returned, since all columns of the left table would be all null anyway.
Right_Exclusive Right_Exclusive
## PRIVATE
to_java : Java_Join_Kind
to_java self = case self of
Join_Kind.Inner -> Java_Join_Kind.INNER
Join_Kind.Left_Outer -> Java_Join_Kind.LEFT_OUTER
Join_Kind.Right_Outer -> Java_Join_Kind.RIGHT_OUTER
Join_Kind.Full -> Java_Join_Kind.FULL
Join_Kind.Left_Exclusive -> Java_Join_Kind.LEFT_ANTI
Join_Kind.Right_Exclusive -> Java_Join_Kind.RIGHT_ANTI

View File

@ -1808,15 +1808,6 @@ type Table
@on Widget_Helpers.make_join_condition_selector @on Widget_Helpers.make_join_condition_selector
join : Table -> Join_Kind -> Vector (Join_Condition | Text) | Text -> Text -> Problem_Behavior -> Table join : Table -> Join_Kind -> Vector (Join_Condition | Text) | Text -> Text -> Problem_Behavior -> Table
join self right:Table (join_kind : Join_Kind = Join_Kind.Left_Outer) on=[Join_Condition.Equals self.column_names.first] right_prefix="Right " on_problems=Report_Warning = Out_Of_Memory.handle_java_exception "join" <| join self right:Table (join_kind : Join_Kind = Join_Kind.Left_Outer) on=[Join_Condition.Equals self.column_names.first] right_prefix="Right " on_problems=Report_Warning = Out_Of_Memory.handle_java_exception "join" <|
# [left_unmatched, matched, right_unmatched]
rows_to_keep = case join_kind of
Join_Kind.Inner -> [False, True, False]
Join_Kind.Left_Outer -> [True, True, False]
Join_Kind.Right_Outer -> [False, True, True]
Join_Kind.Full -> [True, True, True]
Join_Kind.Left_Exclusive -> [True, False, False]
Join_Kind.Right_Exclusive -> [False, False, True]
columns_to_keep = case join_kind of columns_to_keep = case join_kind of
Join_Kind.Left_Exclusive -> [True, False] Join_Kind.Left_Exclusive -> [True, False]
Join_Kind.Right_Exclusive -> [False, True] Join_Kind.Right_Exclusive -> [False, True]
@ -1827,7 +1818,7 @@ type Table
java_conditions = join_resolution.conditions java_conditions = join_resolution.conditions
new_java_table = Java_Problems.with_problem_aggregator on_problems java_aggregator-> new_java_table = Java_Problems.with_problem_aggregator on_problems java_aggregator->
self.java_table.join right.java_table java_conditions (rows_to_keep.at 0) (rows_to_keep.at 1) (rows_to_keep.at 2) (columns_to_keep.at 0) (columns_to_keep.at 1) right_columns_to_drop right_prefix java_aggregator self.java_table.join right.java_table java_conditions join_kind.to_java (columns_to_keep.at 0) (columns_to_keep.at 1) right_columns_to_drop right_prefix java_aggregator
Table.Value new_java_table Table.Value new_java_table
## ALIAS cartesian join ## ALIAS cartesian join

View File

@ -18,6 +18,7 @@ import org.enso.table.data.index.OrderedMultiValueKey;
import org.enso.table.data.mask.OrderMask; import org.enso.table.data.mask.OrderMask;
import org.enso.table.data.mask.SliceRange; import org.enso.table.data.mask.SliceRange;
import org.enso.table.data.table.join.CrossJoin; import org.enso.table.data.table.join.CrossJoin;
import org.enso.table.data.table.join.JoinKind;
import org.enso.table.data.table.join.conditions.JoinCondition; import org.enso.table.data.table.join.conditions.JoinCondition;
import org.enso.table.data.table.join.JoinResult; import org.enso.table.data.table.join.JoinResult;
import org.enso.table.data.table.join.JoinStrategy; import org.enso.table.data.table.join.JoinStrategy;
@ -269,58 +270,17 @@ public class Table {
* form one table. {@code rightColumnsToDrop} allows to drop columns from the right table that are redundant when * form one table. {@code rightColumnsToDrop} allows to drop columns from the right table that are redundant when
* joining on equality of equally named columns. * joining on equality of equally named columns.
*/ */
public Table join(Table right, List<JoinCondition> conditions, boolean keepLeftUnmatched, boolean keepMatched, public Table join(Table right, List<JoinCondition> conditions, JoinKind joinKind, boolean includeLeftColumns, boolean includeRightColumns,
boolean keepRightUnmatched, boolean includeLeftColumns, boolean includeRightColumns,
List<String> rightColumnsToDrop, String right_prefix, ProblemAggregator problemAggregator) { List<String> rightColumnsToDrop, String right_prefix, ProblemAggregator problemAggregator) {
Context context = Context.getCurrent();
NameDeduplicator nameDeduplicator = NameDeduplicator.createDefault(problemAggregator); NameDeduplicator nameDeduplicator = NameDeduplicator.createDefault(problemAggregator);
if (!keepLeftUnmatched && !keepMatched && !keepRightUnmatched) {
throw new IllegalArgumentException("At least one of keepLeftUnmatched, keepMatched or keepRightUnmatched must " +
"be true.");
}
JoinStrategy strategy = JoinStrategy.createStrategy(conditions); JoinStrategy strategy = JoinStrategy.createStrategy(conditions, joinKind);
JoinResult joinResult = strategy.join(problemAggregator); JoinResult joinResult = strategy.join(problemAggregator);
List<JoinResult> resultsToKeep = new ArrayList<>();
if (keepMatched) {
resultsToKeep.add(joinResult);
}
if (keepLeftUnmatched) {
Set<Integer> matchedLeftRows = joinResult.leftMatchedRows();
JoinResult.Builder leftUnmatchedBuilder = new JoinResult.Builder();
for (int i = 0; i < this.rowCount(); i++) {
if (!matchedLeftRows.contains(i)) {
leftUnmatchedBuilder.addRow(i, Index.NOT_FOUND);
}
context.safepoint();
}
resultsToKeep.add(leftUnmatchedBuilder.build());
}
if (keepRightUnmatched) {
Set<Integer> matchedRightRows = joinResult.rightMatchedRows();
JoinResult.Builder rightUnmatchedBuilder = new JoinResult.Builder();
for (int i = 0; i < right.rowCount(); i++) {
if (!matchedRightRows.contains(i)) {
rightUnmatchedBuilder.addRow(Index.NOT_FOUND, i);
}
context.safepoint();
}
resultsToKeep.add(rightUnmatchedBuilder.build());
}
List<Column> newColumns = new ArrayList<>(); List<Column> newColumns = new ArrayList<>();
if (includeLeftColumns) { if (includeLeftColumns) {
OrderMask leftMask = OrderMask leftMask = joinResult.getLeftOrderMask();
OrderMask.concat(resultsToKeep.stream().map(JoinResult::getLeftOrderMask).collect(Collectors.toList()));
for (Column column : this.columns) { for (Column column : this.columns) {
Column newColumn = column.applyMask(leftMask); Column newColumn = column.applyMask(leftMask);
newColumns.add(newColumn); newColumns.add(newColumn);
@ -328,14 +288,13 @@ public class Table {
} }
if (includeRightColumns) { if (includeRightColumns) {
OrderMask rightMask = OrderMask rightMask = joinResult.getRightOrderMask();
OrderMask.concat(resultsToKeep.stream().map(JoinResult::getRightOrderMask).collect(Collectors.toList())); List<String> leftColumnNames = newColumns.stream().map(Column::getName).toList();
List<String> leftColumnNames = newColumns.stream().map(Column::getName).collect(Collectors.toList());
HashSet<String> toDrop = new HashSet<>(rightColumnsToDrop); HashSet<String> toDrop = new HashSet<>(rightColumnsToDrop);
List<Column> rightColumnsToKeep = List<Column> rightColumnsToKeep =
Arrays.stream(right.getColumns()).filter(col -> !toDrop.contains(col.getName())).collect(Collectors.toList()); Arrays.stream(right.getColumns()).filter(col -> !toDrop.contains(col.getName())).toList();
List<String> rightColumNames = rightColumnsToKeep.stream().map(Column::getName).collect(Collectors.toList()); List<String> rightColumNames = rightColumnsToKeep.stream().map(Column::getName).toList();
List<String> newRightColumnNames = nameDeduplicator.combineWithPrefix(leftColumnNames, rightColumNames, List<String> newRightColumnNames = nameDeduplicator.combineWithPrefix(leftColumnNames, rightColumNames,
right_prefix); right_prefix);

View File

@ -5,10 +5,12 @@ import org.graalvm.polyglot.Context;
public class CrossJoin { public class CrossJoin {
public static JoinResult perform(int leftRowCount, int rightRowCount) { public static JoinResult perform(int leftRowCount, int rightRowCount) {
Context context = Context.getCurrent(); Context context = Context.getCurrent();
JoinResult.Builder resultBuilder = new JoinResult.Builder(leftRowCount * rightRowCount); JoinResult.BuilderSettings settings = new JoinResult.BuilderSettings(true, true, true);
JoinResult.Builder resultBuilder =
new JoinResult.Builder(leftRowCount * rightRowCount, settings);
for (int l = 0; l < leftRowCount; ++l) { for (int l = 0; l < leftRowCount; ++l) {
for (int r = 0; r < rightRowCount; ++r) { for (int r = 0; r < rightRowCount; ++r) {
resultBuilder.addRow(l, r); resultBuilder.addMatchedRowsPair(l, r);
context.safepoint(); context.safepoint();
} }
} }

View File

@ -0,0 +1,21 @@
package org.enso.table.data.table.join;
public enum JoinKind {
INNER,
FULL,
LEFT_OUTER,
RIGHT_OUTER,
LEFT_ANTI,
RIGHT_ANTI;
public static JoinResult.BuilderSettings makeSettings(JoinKind joinKind) {
return switch (joinKind) {
case INNER -> new JoinResult.BuilderSettings(true, false, false);
case FULL -> new JoinResult.BuilderSettings(true, true, true);
case LEFT_OUTER -> new JoinResult.BuilderSettings(true, true, false);
case RIGHT_OUTER -> new JoinResult.BuilderSettings(true, false, true);
case LEFT_ANTI -> new JoinResult.BuilderSettings(false, true, false);
case RIGHT_ANTI -> new JoinResult.BuilderSettings(false, false, true);
};
}
}

View File

@ -3,11 +3,6 @@ package org.enso.table.data.table.join;
import org.enso.base.arrays.IntArrayBuilder; import org.enso.base.arrays.IntArrayBuilder;
import org.enso.table.data.mask.OrderMask; import org.enso.table.data.mask.OrderMask;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;
public record JoinResult(int[] matchedRowsLeftIndices, int[] matchedRowsRightIndices) { public record JoinResult(int[] matchedRowsLeftIndices, int[] matchedRowsRightIndices) {
public OrderMask getLeftOrderMask() { public OrderMask getLeftOrderMask() {
@ -18,32 +13,39 @@ public record JoinResult(int[] matchedRowsLeftIndices, int[] matchedRowsRightInd
return new OrderMask(matchedRowsRightIndices); return new OrderMask(matchedRowsRightIndices);
} }
public Set<Integer> leftMatchedRows() { public record BuilderSettings(boolean wantsCommon, boolean wantsLeftUnmatched, boolean wantsRightUnmatched) {}
return new HashSet<>(Arrays.stream(matchedRowsLeftIndices).boxed().collect(Collectors.toList()));
}
public Set<Integer> rightMatchedRows() {
return new HashSet<>(Arrays.stream(matchedRowsRightIndices).boxed().collect(Collectors.toList()));
}
public static class Builder { public static class Builder {
IntArrayBuilder leftIndices; IntArrayBuilder leftIndices;
IntArrayBuilder rightIndices; IntArrayBuilder rightIndices;
public Builder(int initialCapacity) { final BuilderSettings settings;
public Builder(int initialCapacity, BuilderSettings settings) {
leftIndices = new IntArrayBuilder(initialCapacity); leftIndices = new IntArrayBuilder(initialCapacity);
rightIndices = new IntArrayBuilder(initialCapacity); rightIndices = new IntArrayBuilder(initialCapacity);
this.settings = settings;
} }
public Builder() { public Builder(BuilderSettings settings) {
this(128); this(128, settings);
} }
public void addRow(int leftIndex, int rightIndex) { public void addMatchedRowsPair(int leftIndex, int rightIndex) {
leftIndices.add(leftIndex); leftIndices.add(leftIndex);
rightIndices.add(rightIndex); rightIndices.add(rightIndex);
} }
public void addUnmatchedLeftRow(int leftIndex) {
leftIndices.add(leftIndex);
rightIndices.add(-1);
}
public void addUnmatchedRightRow(int rightIndex) {
leftIndices.add(-1);
rightIndices.add(rightIndex);
}
public JoinResult build() { public JoinResult build() {
return new JoinResult(leftIndices.build(), rightIndices.build()); return new JoinResult(leftIndices.build(), rightIndices.build());
} }

View File

@ -17,11 +17,13 @@ import java.util.List;
public interface JoinStrategy { public interface JoinStrategy {
JoinResult join(ProblemAggregator problemAggregator); JoinResult join(ProblemAggregator problemAggregator);
static JoinStrategy createStrategy(List<JoinCondition> conditions) { static JoinStrategy createStrategy(List<JoinCondition> conditions, JoinKind joinKind) {
if (conditions.isEmpty()) { if (conditions.isEmpty()) {
throw new IllegalArgumentException("At least one join condition must be provided."); throw new IllegalArgumentException("At least one join condition must be provided.");
} }
JoinResult.BuilderSettings builderSettings = JoinKind.makeSettings(joinKind);
List<HashableCondition> hashableConditions = conditions.stream() List<HashableCondition> hashableConditions = conditions.stream()
.filter(c -> c instanceof HashableCondition) .filter(c -> c instanceof HashableCondition)
.map(c -> (HashableCondition) c) .map(c -> (HashableCondition) c)
@ -37,11 +39,11 @@ public interface JoinStrategy {
if (hashableConditions.isEmpty()) { if (hashableConditions.isEmpty()) {
assert !betweenConditions.isEmpty(); assert !betweenConditions.isEmpty();
return new SortJoin(betweenConditions); return new SortJoin(betweenConditions, builderSettings);
} else if (betweenConditions.isEmpty()) { } else if (betweenConditions.isEmpty()) {
return new HashJoin(hashableConditions, new MatchAllStrategy()); return new HashJoin(hashableConditions, new MatchAllStrategy(), builderSettings);
} else { } else {
return new HashJoin(hashableConditions, new SortJoin(betweenConditions)); return new HashJoin(hashableConditions, new SortJoin(betweenConditions, builderSettings), builderSettings);
} }
} }

View File

@ -15,10 +15,14 @@ public class MatchAllStrategy implements PluggableJoinStrategy {
List<Integer> rightGroup, List<Integer> rightGroup,
JoinResult.Builder resultBuilder, JoinResult.Builder resultBuilder,
ProblemAggregator problemAggregator) { ProblemAggregator problemAggregator) {
if (!resultBuilder.settings.wantsCommon()) {
return;
}
Context context = Context.getCurrent(); Context context = Context.getCurrent();
for (var leftRow : leftGroup) { for (var leftRow : leftGroup) {
for (var rightRow : rightGroup) { for (var rightRow : rightGroup) {
resultBuilder.addRow(leftRow, rightRow); resultBuilder.addMatchedRowsPair(leftRow, rightRow);
context.safepoint(); context.safepoint();
} }

View File

@ -1,6 +1,7 @@
package org.enso.table.data.table.join.between; package org.enso.table.data.table.join.between;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.BitSet;
import java.util.Comparator; import java.util.Comparator;
import java.util.List; import java.util.List;
import org.enso.base.ObjectComparator; import org.enso.base.ObjectComparator;
@ -15,8 +16,9 @@ import org.graalvm.polyglot.Context;
public class SortJoin implements JoinStrategy, PluggableJoinStrategy { public class SortJoin implements JoinStrategy, PluggableJoinStrategy {
public SortJoin(List<Between> conditions) { public SortJoin(List<Between> conditions, JoinResult.BuilderSettings resultBuilderSettings) {
conditionsHelper = new JoinStrategy.ConditionsHelper(conditions); conditionsHelper = new JoinStrategy.ConditionsHelper(conditions);
this.resultBuilderSettings = resultBuilderSettings;
Context context = Context.getCurrent(); Context context = Context.getCurrent();
int nConditions = conditions.size(); int nConditions = conditions.size();
@ -34,16 +36,18 @@ public class SortJoin implements JoinStrategy, PluggableJoinStrategy {
} }
private final JoinStrategy.ConditionsHelper conditionsHelper; private final JoinStrategy.ConditionsHelper conditionsHelper;
private final JoinResult.BuilderSettings resultBuilderSettings;
private final int[] directions; private final int[] directions;
private final Storage<?>[] leftStorages; private final Storage<?>[] leftStorages;
private final Storage<?>[] lowerStorages; private final Storage<?>[] lowerStorages;
private final Storage<?>[] upperStorages; private final Storage<?>[] upperStorages;
private final BitSet matchedLeftRows = new BitSet();
@Override @Override
public JoinResult join(ProblemAggregator problemAggregator) { public JoinResult join(ProblemAggregator problemAggregator) {
Context context = Context.getCurrent(); Context context = Context.getCurrent();
JoinResult.Builder resultBuilder = new JoinResult.Builder(); JoinResult.Builder resultBuilder = new JoinResult.Builder(resultBuilderSettings);
int leftRowCount = conditionsHelper.getLeftTableRowCount(); int leftRowCount = conditionsHelper.getLeftTableRowCount();
int rightRowCount = conditionsHelper.getRightTableRowCount(); int rightRowCount = conditionsHelper.getRightTableRowCount();
@ -60,10 +64,22 @@ public class SortJoin implements JoinStrategy, PluggableJoinStrategy {
SortedListIndex<OrderedMultiValueKey> leftIndex = buildSortedLeftIndex(leftKeys); SortedListIndex<OrderedMultiValueKey> leftIndex = buildSortedLeftIndex(leftKeys);
for (int rightRowIx = 0; rightRowIx < rightRowCount; rightRowIx++) { for (int rightRowIx = 0; rightRowIx < rightRowCount; rightRowIx++) {
addMatchingLeftRows(leftIndex, rightRowIx, resultBuilder); int matches = addMatchingLeftRows(leftIndex, rightRowIx, resultBuilder);
if (resultBuilderSettings.wantsRightUnmatched() && matches == 0) {
resultBuilder.addUnmatchedRightRow(rightRowIx);
}
context.safepoint(); context.safepoint();
} }
if (resultBuilderSettings.wantsLeftUnmatched()) {
for (int leftRowIx = 0; leftRowIx < leftRowCount; leftRowIx++) {
if (!matchedLeftRows.get(leftRowIx)) {
resultBuilder.addUnmatchedLeftRow(leftRowIx);
}
context.safepoint();
}
}
return resultBuilder.build(); return resultBuilder.build();
} }
@ -87,9 +103,21 @@ public class SortJoin implements JoinStrategy, PluggableJoinStrategy {
SortedListIndex<OrderedMultiValueKey> leftIndex = buildSortedLeftIndex(leftKeys); SortedListIndex<OrderedMultiValueKey> leftIndex = buildSortedLeftIndex(leftKeys);
for (int rightRowIx : rightGroup) { for (int rightRowIx : rightGroup) {
addMatchingLeftRows(leftIndex, rightRowIx, resultBuilder); int matches = addMatchingLeftRows(leftIndex, rightRowIx, resultBuilder);
if (resultBuilderSettings.wantsRightUnmatched() && matches == 0) {
resultBuilder.addUnmatchedRightRow(rightRowIx);
}
context.safepoint(); context.safepoint();
} }
if (resultBuilderSettings.wantsLeftUnmatched()) {
for (int leftRowIx : leftGroup) {
if (!matchedLeftRows.get(leftRowIx)) {
resultBuilder.addUnmatchedLeftRow(leftRowIx);
}
context.safepoint();
}
}
} }
private SortedListIndex<OrderedMultiValueKey> buildSortedLeftIndex( private SortedListIndex<OrderedMultiValueKey> buildSortedLeftIndex(
@ -105,7 +133,13 @@ public class SortJoin implements JoinStrategy, PluggableJoinStrategy {
return new OrderedMultiValueKey(upperStorages, rightRowIx, directions, objectComparator); return new OrderedMultiValueKey(upperStorages, rightRowIx, directions, objectComparator);
} }
private void addMatchingLeftRows( /**
* Adds all pairs of rows from the left index matching the right index to the builder, and reports
* the match count.
*
* <p>It also marks any of the left rows that were matched, in the {@code matchedLeftRows}.
*/
private int addMatchingLeftRows(
SortedListIndex<OrderedMultiValueKey> sortedLeftIndex, SortedListIndex<OrderedMultiValueKey> sortedLeftIndex,
int rightRowIx, int rightRowIx,
JoinResult.Builder resultBuilder) { JoinResult.Builder resultBuilder) {
@ -116,19 +150,30 @@ public class SortJoin implements JoinStrategy, PluggableJoinStrategy {
if (lowerBound.hasAnyNulls() if (lowerBound.hasAnyNulls()
|| upperBound.hasAnyNulls() || upperBound.hasAnyNulls()
|| lowerBound.compareTo(upperBound) > 0) { || lowerBound.compareTo(upperBound) > 0) {
return; return 0;
} }
int matchCount = 0;
List<OrderedMultiValueKey> firstCoordinateMatches = List<OrderedMultiValueKey> firstCoordinateMatches =
sortedLeftIndex.findSubRange(lowerBound, upperBound); sortedLeftIndex.findSubRange(lowerBound, upperBound);
Context context = Context.getCurrent(); Context context = Context.getCurrent();
for (OrderedMultiValueKey key : firstCoordinateMatches) { for (OrderedMultiValueKey key : firstCoordinateMatches) {
if (isInRange(key, lowerBound, upperBound)) { if (isInRange(key, lowerBound, upperBound)) {
resultBuilder.addRow(key.getRowIndex(), rightRowIx); int leftRowIx = key.getRowIndex();
matchCount++;
if (resultBuilderSettings.wantsCommon()) {
resultBuilder.addMatchedRowsPair(leftRowIx, rightRowIx);
}
if (resultBuilderSettings.wantsLeftUnmatched()) {
matchedLeftRows.set(leftRowIx);
}
} }
context.safepoint(); context.safepoint();
} }
return matchCount;
} }
private boolean isInRange( private boolean isInRange(

View File

@ -13,6 +13,7 @@ import org.enso.table.data.table.join.conditions.HashableCondition;
import org.enso.table.problems.ProblemAggregator; import org.enso.table.problems.ProblemAggregator;
import org.graalvm.polyglot.Context; import org.graalvm.polyglot.Context;
import java.util.HashMap;
import java.util.List; import java.util.List;
/** /**
@ -22,9 +23,10 @@ import java.util.List;
* subsets. * subsets.
*/ */
public class HashJoin implements JoinStrategy { public class HashJoin implements JoinStrategy {
public HashJoin(List<HashableCondition> conditions, PluggableJoinStrategy remainingMatcher) { public HashJoin(List<HashableCondition> conditions, PluggableJoinStrategy remainingMatcher, JoinResult.BuilderSettings resultBuilderSettings) {
conditionsHelper = new JoinStrategy.ConditionsHelper(conditions); conditionsHelper = new JoinStrategy.ConditionsHelper(conditions);
this.remainingMatcher = remainingMatcher; this.remainingMatcher = remainingMatcher;
this.resultBuilderSettings = resultBuilderSettings;
List<HashEqualityCondition> equalConditions = List<HashEqualityCondition> equalConditions =
conditions.stream().map(HashJoin::makeHashEqualityCondition).toList(); conditions.stream().map(HashJoin::makeHashEqualityCondition).toList();
@ -42,6 +44,7 @@ public class HashJoin implements JoinStrategy {
private final Column[] leftEquals, rightEquals; private final Column[] leftEquals, rightEquals;
private final List<TextFoldingStrategy> textFoldingStrategies; private final List<TextFoldingStrategy> textFoldingStrategies;
private final PluggableJoinStrategy remainingMatcher; private final PluggableJoinStrategy remainingMatcher;
private final JoinResult.BuilderSettings resultBuilderSettings;
@Override @Override
public JoinResult join(ProblemAggregator problemAggregator) { public JoinResult join(ProblemAggregator problemAggregator) {
@ -52,7 +55,7 @@ public class HashJoin implements JoinStrategy {
var rightIndex = MultiValueIndex.makeUnorderedIndex(rightEquals, conditionsHelper.getRightTableRowCount(), var rightIndex = MultiValueIndex.makeUnorderedIndex(rightEquals, conditionsHelper.getRightTableRowCount(),
textFoldingStrategies, problemAggregator); textFoldingStrategies, problemAggregator);
JoinResult.Builder resultBuilder = new JoinResult.Builder(); JoinResult.Builder resultBuilder = new JoinResult.Builder(resultBuilderSettings);
for (var leftEntry : leftIndex.mapping().entrySet()) { for (var leftEntry : leftIndex.mapping().entrySet()) {
UnorderedMultiValueKey leftKey = leftEntry.getKey(); UnorderedMultiValueKey leftKey = leftEntry.getKey();
List<Integer> leftRows = leftEntry.getValue(); List<Integer> leftRows = leftEntry.getValue();
@ -60,11 +63,30 @@ public class HashJoin implements JoinStrategy {
if (rightRows != null) { if (rightRows != null) {
remainingMatcher.joinSubsets(leftRows, rightRows, resultBuilder, problemAggregator); remainingMatcher.joinSubsets(leftRows, rightRows, resultBuilder, problemAggregator);
} else {
if (resultBuilderSettings.wantsLeftUnmatched()) {
for (int leftRow : leftRows) {
resultBuilder.addUnmatchedLeftRow(leftRow);
context.safepoint();
}
}
} }
context.safepoint(); context.safepoint();
} }
if (resultBuilderSettings.wantsRightUnmatched()) {
for (var rightEntry : rightIndex.mapping().entrySet()) {
UnorderedMultiValueKey rightKey = rightEntry.getKey();
boolean wasCompletelyUnmatched = !leftIndex.contains(rightKey);
if (wasCompletelyUnmatched) {
for (int rightRow : rightEntry.getValue()) {
resultBuilder.addUnmatchedRightRow(rightRow);
}
}
}
}
return resultBuilder.build(); return resultBuilder.build();
} }

View File

@ -213,8 +213,7 @@ collect_benches = Bench.build builder->
r = scenario.table1.join t2 on=[Join_Condition.Between "x" "x_lows" "x_highs", Join_Condition.Between "y" "y_lows" "y_highs"] r = scenario.table1.join t2 on=[Join_Condition.Between "x" "x_lows" "x_highs", Join_Condition.Between "y" "y_lows" "y_highs"]
assert (r.row_count == scenario.table1.row_count) assert (r.row_count == scenario.table1.row_count)
# TODO this should be part of the main tests, but it was causing issues on CI; re-enable this with #8217 group_builder.specify "AntiJoin" <|
if extended_tests then group_builder.specify "AntiJoin" <|
scenario = data.antijoin scenario = data.antijoin
r = scenario.table2.join scenario.table1 on="key" join_kind=Join_Kind.Left_Exclusive r = scenario.table2.join scenario.table1 on="key" join_kind=Join_Kind.Left_Exclusive
assert (r.row_count == 1000) assert (r.row_count == 1000)

View File

@ -284,18 +284,37 @@ spec setup =
Test.specify "should allow to mix join conditions of various kinds" <| Test.specify "should allow to mix join conditions of various kinds" <|
t1 = table_builder [["X", [1, 12, 12, 0]], ["Y", [1, 2, 3, 4]], ["Z", ["a", "A", "a", "ą"]], ["W", [1, 2, 3, 4]]] t1 = table_builder [["X", [1, 12, 12, 0]], ["Y", [1, 2, 3, 4]], ["Z", ["a", "A", "a", "ą"]], ["W", [1, 2, 3, 4]]]
t2 = table_builder [["X", [12, 12, 1]], ["l", [0, 100, 100]], ["u", [10, 100, 100]], ["Z", ["A", "A", "A"]], ["W'", [10, 20, 30]]] t2 = table_builder [["X", [12, 12, 1]], ["l", [0, 100, 100]], ["u", [10, 100, 200]], ["Z", ["A", "A", "A"]], ["W'", [10, 20, 30]]]
r1 = t1.join t2 join_kind=Join_Kind.Inner on=[Join_Condition.Between "Y" "l" "u", Join_Condition.Equals_Ignore_Case "Z" "Z", Join_Condition.Equals "X" "X"] |> materialize |> _.order_by ["Y"] conditions = [Join_Condition.Between "Y" "l" "u", Join_Condition.Equals_Ignore_Case "Z" "Z", Join_Condition.Equals "X" "X"]
expect_column_names ["X", "Y", "Z", "W", "l", "u", "Right Z", "W'"] r1 r1 = t1.join t2 join_kind=Join_Kind.Inner on=conditions |> materialize |> _.order_by ["Y"]
r1.at "X" . to_vector . should_equal [12, 12] within_table r1 <|
r1.at "Y" . to_vector . should_equal [2, 3] r1.column_names.should_equal ["X", "Y", "Z", "W", "l", "u", "Right Z", "W'"]
r1.at "Z" . to_vector . should_equal ["A", "a"] r1.at "X" . to_vector . should_equal [12, 12]
r1.at "W" . to_vector . should_equal [2, 3] r1.at "Y" . to_vector . should_equal [2, 3]
r1.at "l" . to_vector . should_equal [0, 0] r1.at "Z" . to_vector . should_equal ["A", "a"]
r1.at "u" . to_vector . should_equal [10, 10] r1.at "W" . to_vector . should_equal [2, 3]
r1.at "Right Z" . to_vector . should_equal ["A", "A"] r1.at "l" . to_vector . should_equal [0, 0]
r1.at "W'" . to_vector . should_equal [10, 10] r1.at "u" . to_vector . should_equal [10, 10]
r1.at "Right Z" . to_vector . should_equal ["A", "A"]
r1.at "W'" . to_vector . should_equal [10, 10]
r2 = t1.join t2 join_kind=Join_Kind.Left_Exclusive on=conditions |> materialize |> _.order_by ["Y"]
within_table r2 <|
r2.column_names.should_equal ["X", "Y", "Z", "W"]
r2.at "X" . to_vector . should_equal [1, 0]
r2.at "Y" . to_vector . should_equal [1, 4]
r2.at "Z" . to_vector . should_equal ["a", "ą"]
r2.at "W" . to_vector . should_equal [1, 4]
r3 = t1.join t2 join_kind=Join_Kind.Right_Exclusive on=conditions |> materialize |> _.order_by ["W'"]
within_table r3 <|
r3.column_names.should_equal ["X", "l", "u", "Z", "W'"]
r3.at "X" . to_vector . should_equal [12, 1]
r3.at "l" . to_vector . should_equal [100, 100]
r3.at "u" . to_vector . should_equal [100, 200]
r3.at "Z" . to_vector . should_equal ["A", "A"]
r3.at "W'" . to_vector . should_equal [20, 30]
Test.specify "should work fine if the same condition is specified multiple times" <| Test.specify "should work fine if the same condition is specified multiple times" <|
r = t3.join t4 join_kind=Join_Kind.Inner on=["X", "X", "Y", "X", "Y"] |> materialize |> _.order_by ["X", "Y", "Z", "Right Z"] r = t3.join t4 join_kind=Join_Kind.Inner on=["X", "X", "Y", "X", "Y"] |> materialize |> _.order_by ["X", "Y", "Z", "Right Z"]