Make table.Running return integer typed columns for min/max (#9853)

* New Tests

* Green

* Running min for longs

* Unsupported types test

* Revert

* Add support for all the integer types

* Another test
This commit is contained in:
AdRiley 2024-05-07 12:49:12 +03:00 committed by GitHub
parent 930f3c593e
commit 15976a8505
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 240 additions and 24 deletions

View File

@ -5,6 +5,8 @@ import org.enso.base.polyglot.NumericConverter;
import org.enso.base.statistics.Statistic; import org.enso.base.statistics.Statistic;
import org.enso.table.data.column.storage.Storage; import org.enso.table.data.column.storage.Storage;
import org.enso.table.data.column.storage.numeric.DoubleStorage; import org.enso.table.data.column.storage.numeric.DoubleStorage;
import org.enso.table.data.column.storage.numeric.LongStorage;
import org.enso.table.data.column.storage.type.IntegerType;
import org.enso.table.data.table.Column; import org.enso.table.data.table.Column;
import org.enso.table.data.table.problems.IgnoredNaN; import org.enso.table.data.table.problems.IgnoredNaN;
import org.enso.table.data.table.problems.IgnoredNothing; import org.enso.table.data.table.problems.IgnoredNothing;
@ -31,7 +33,7 @@ public class AddRunning {
return runningStatistic.getResult(); return runningStatistic.getResult();
} }
private static RunningStatistic<Double> createRunningStatistic( private static RunningStatistic<?> createRunningStatistic(
Statistic statistic, Column sourceColumn, ProblemAggregator problemAggregator) { Statistic statistic, Column sourceColumn, ProblemAggregator problemAggregator) {
switch (statistic) { switch (statistic) {
case Sum -> { case Sum -> {
@ -41,38 +43,98 @@ public class AddRunning {
return new RunningMeanStatistic(sourceColumn, problemAggregator); return new RunningMeanStatistic(sourceColumn, problemAggregator);
} }
case Minimum -> { case Minimum -> {
if (sourceColumn.getStorage().getType() instanceof IntegerType type) {
return new RunningMinLongStatistic(sourceColumn, problemAggregator, type);
}
return new RunningMinStatistic(sourceColumn, problemAggregator); return new RunningMinStatistic(sourceColumn, problemAggregator);
} }
case Maximum -> { case Maximum -> {
if (sourceColumn.getStorage().getType() instanceof IntegerType type) {
return new RunningMaxLongStatistic(sourceColumn, problemAggregator, type);
}
return new RunningMaxStatistic(sourceColumn, problemAggregator); return new RunningMaxStatistic(sourceColumn, problemAggregator);
} }
default -> throw new IllegalArgumentException("Unsupported statistic: " + statistic); default -> throw new IllegalArgumentException("Unsupported statistic: " + statistic);
} }
} }
private abstract static class RunningStatisticBase implements RunningStatistic<Double> { private interface TypeHandler<T> {
T tryConvertingToType(Object o);
long typeToRawLongBits(T t);
Storage<T> createStorage(long[] result, int size, BitSet isNothing);
}
private static class DoubleHandler implements TypeHandler<Double> {
@Override
public Double tryConvertingToType(Object o) {
return NumericConverter.tryConvertingToDouble(o);
}
@Override
public long typeToRawLongBits(Double d) {
return Double.doubleToRawLongBits(d);
}
@Override
public Storage<Double> createStorage(long[] result, int size, BitSet isNothing) {
return new DoubleStorage(result, size, isNothing);
}
}
private static class LongHandler implements TypeHandler<Long> {
IntegerType type;
LongHandler(IntegerType type) {
this.type = type;
}
@Override
public Long tryConvertingToType(Object o) {
return NumericConverter.tryConvertingToLong(o);
}
@Override
public long typeToRawLongBits(Long l) {
return l;
}
@Override
public Storage<Long> createStorage(long[] result, int size, BitSet isNothing) {
return new LongStorage(result, size, isNothing, type);
}
}
private abstract static class RunningStatisticBase<T> implements RunningStatistic<T> {
long[] result; long[] result;
BitSet isNothing; BitSet isNothing;
ColumnAggregatedProblemAggregator columnAggregatedProblemAggregator; ColumnAggregatedProblemAggregator columnAggregatedProblemAggregator;
Column sourceColumn; Column sourceColumn;
TypeHandler<T> typeHandler;
RunningStatisticBase(Column sourceColumn, ProblemAggregator problemAggregator) { RunningStatisticBase(
Column sourceColumn, ProblemAggregator problemAggregator, TypeHandler<T> typeHandler) {
result = new long[sourceColumn.getSize()]; result = new long[sourceColumn.getSize()];
isNothing = new BitSet(); isNothing = new BitSet();
columnAggregatedProblemAggregator = new ColumnAggregatedProblemAggregator(problemAggregator); columnAggregatedProblemAggregator = new ColumnAggregatedProblemAggregator(problemAggregator);
this.sourceColumn = sourceColumn; this.sourceColumn = sourceColumn;
this.typeHandler = typeHandler;
} }
@Override @Override
public void calculateNextValue(int i, RunningIterator<Double> it) { public void calculateNextValue(int i, RunningIterator<T> it) {
Object value = sourceColumn.getStorage().getItemBoxed(i); Object value = sourceColumn.getStorage().getItemBoxed(i);
if (value == null) { if (value == null) {
columnAggregatedProblemAggregator.reportColumnAggregatedProblem( columnAggregatedProblemAggregator.reportColumnAggregatedProblem(
new IgnoredNothing(sourceColumn.getName(), i)); new IgnoredNothing(sourceColumn.getName(), i));
} }
Double dValue = NumericConverter.tryConvertingToDouble(value); T dValue = typeHandler.tryConvertingToType(value);
Double dNextValue; T dNextValue;
if (dValue != null && dValue.equals(Double.NaN)) { if (dValue != null && dValue.equals(Double.NaN)) {
columnAggregatedProblemAggregator.reportColumnAggregatedProblem( columnAggregatedProblemAggregator.reportColumnAggregatedProblem(
new IgnoredNaN(sourceColumn.getName(), i)); new IgnoredNaN(sourceColumn.getName(), i));
@ -83,13 +145,13 @@ public class AddRunning {
if (dNextValue == null) { if (dNextValue == null) {
isNothing.set(i); isNothing.set(i);
} else { } else {
result[i] = Double.doubleToRawLongBits(dNextValue); result[i] = typeHandler.typeToRawLongBits(dNextValue);
} }
} }
@Override @Override
public Storage<Double> getResult() { public Storage<T> getResult() {
return new DoubleStorage(result, sourceColumn.getSize(), isNothing); return typeHandler.createStorage(result, sourceColumn.getSize(), isNothing);
} }
} }
@ -127,10 +189,10 @@ public class AddRunning {
} }
} }
private static class RunningSumStatistic extends RunningStatisticBase { private static class RunningSumStatistic extends RunningStatisticBase<Double> {
RunningSumStatistic(Column sourceColumn, ProblemAggregator problemAggregator) { RunningSumStatistic(Column sourceColumn, ProblemAggregator problemAggregator) {
super(sourceColumn, problemAggregator); super(sourceColumn, problemAggregator, new DoubleHandler());
} }
@Override @Override
@ -147,10 +209,10 @@ public class AddRunning {
} }
} }
private static class RunningMeanStatistic extends RunningStatisticBase { private static class RunningMeanStatistic extends RunningStatisticBase<Double> {
RunningMeanStatistic(Column sourceColumn, ProblemAggregator problemAggregator) { RunningMeanStatistic(Column sourceColumn, ProblemAggregator problemAggregator) {
super(sourceColumn, problemAggregator); super(sourceColumn, problemAggregator, new DoubleHandler());
} }
@Override @Override
@ -181,10 +243,10 @@ public class AddRunning {
} }
} }
private static class RunningMinStatistic extends RunningStatisticBase { private static class RunningMinStatistic extends RunningStatisticBase<Double> {
RunningMinStatistic(Column sourceColumn, ProblemAggregator problemAggregator) { RunningMinStatistic(Column sourceColumn, ProblemAggregator problemAggregator) {
super(sourceColumn, problemAggregator); super(sourceColumn, problemAggregator, new DoubleHandler());
} }
@Override @Override
@ -201,10 +263,31 @@ public class AddRunning {
} }
} }
private static class RunningMaxStatistic extends RunningStatisticBase { private static class RunningMinLongStatistic extends RunningStatisticBase<Long> {
RunningMinLongStatistic(
Column sourceColumn, ProblemAggregator problemAggregator, IntegerType type) {
super(sourceColumn, problemAggregator, new LongHandler(type));
}
@Override
public RunningIterator<Long> getNewIterator() {
return new RunningMinLongIterator();
}
private static class RunningMinLongIterator extends RunningIteratorLong {
@Override
public void increment(long value) {
current = Math.min(current, value);
}
}
}
private static class RunningMaxStatistic extends RunningStatisticBase<Double> {
RunningMaxStatistic(Column sourceColumn, ProblemAggregator problemAggregator) { RunningMaxStatistic(Column sourceColumn, ProblemAggregator problemAggregator) {
super(sourceColumn, problemAggregator); super(sourceColumn, problemAggregator, new DoubleHandler());
} }
@Override @Override
@ -220,4 +303,59 @@ public class AddRunning {
} }
} }
} }
private static class RunningMaxLongStatistic extends RunningStatisticBase<Long> {
RunningMaxLongStatistic(
Column sourceColumn, ProblemAggregator problemAggregator, IntegerType type) {
super(sourceColumn, problemAggregator, new LongHandler(type));
}
@Override
public RunningIterator<Long> getNewIterator() {
return new RunningMaxLongIterator();
}
private static class RunningMaxLongIterator extends RunningIteratorLong {
@Override
public void increment(long value) {
current = Math.max(current, value);
}
}
}
private abstract static class RunningIteratorLong implements RunningIterator<Long> {
protected long current;
private boolean isInitialized = false;
@Override
public Long next(Long value) {
if (value != null) {
if (!isInitialized) {
isInitialized = true;
initialize(value);
} else {
increment(value);
}
}
return isInitialized ? getCurrent() : null;
}
@Override
public Long currentValue() {
return isInitialized ? getCurrent() : null;
}
protected void initialize(long value) {
current = value;
}
protected abstract void increment(long value);
protected long getCurrent() {
return current;
}
}
} }

View File

@ -1,5 +1,5 @@
from Standard.Base import all from Standard.Base import all
from Standard.Table import Column, Table from Standard.Table import Column, Table, Bits, Value_Type
from Standard.Test import all from Standard.Test import all
from Standard.Table.Errors import all from Standard.Table.Errors import all
import Standard.Base.Errors.Common.Type_Error import Standard.Base.Errors.Common.Type_Error
@ -15,16 +15,18 @@ type Data
# 2 | SG0456 | A | 73.23 # 2 | SG0456 | A | 73.23
# 3 | BA0123 | C | 112.34 # 3 | BA0123 | C | 112.34
# 4 | SG0456 | E | 73.77 # 4 | SG0456 | E | 73.77
Value ~table Value ~table ~integer_table
setup = setup =
flight = ["Flight", ["BA0123", "BA0123", "SG0456", "BA0123", "SG0456"]]
passenger = ["Passenger", ["A", "B", "A", "C", "E"]]
make_table = make_table =
flight = ["Flight", ["BA0123", "BA0123", "SG0456", "BA0123", "SG0456"]]
passenger = ["Passenger", ["A", "B", "A", "C", "E"]]
ticket_price = ["Ticket Price", [100.50, 575.99, 73.23, 112.34, 73.77]] ticket_price = ["Ticket Price", [100.50, 575.99, 73.23, 112.34, 73.77]]
Table.new [flight, passenger, ticket_price] Table.new [flight, passenger, ticket_price]
Data.Value make_table make_integer_table =
ticket_price = ["Ticket Price", [101, 576, 73, 112, 74]]
Table.new [flight, passenger, ticket_price]
Data.Value make_table make_integer_table
add_specs suite_builder = add_specs suite_builder =
suite_builder.group "running count" group_builder-> suite_builder.group "running count" group_builder->
@ -142,7 +144,7 @@ add_specs suite_builder =
group_builder.specify "Can provide running sum based on order by without grouping" <| group_builder.specify "Can provide running sum based on order by without grouping" <|
result = data.table.running Statistic.Sum "Ticket Price" "Sum ticket cost" [] ["Ticket Price"] result = data.table.running Statistic.Sum "Ticket Price" "Sum ticket cost" [] ["Ticket Price"]
expected_column = Column.from_vector "Sum ticket cost" [247.5, 935.83, 73.23, 359.84000000000003, 147] expected_column = Column.from_vector "Sum ticket cost" [247.5, 935.83, 73.23, 359.84000000000003, 147]
# | Flight | Passenger | Ticket Price | Ranked ticket cost # | Flight | Passenger | Ticket Price | Sum ticket cost
#---+--------+-----------+--------------+------------------------- #---+--------+-----------+--------------+-------------------------
# 0 | BA0123 | A | 100.5 | 3 # 0 | BA0123 | A | 100.5 | 3
# 1 | BA0123 | B | 575.99 | 5 # 1 | BA0123 | B | 575.99 | 5
@ -151,6 +153,18 @@ add_specs suite_builder =
# 4 | SG0456 | E | 73.77 | 2 # 4 | SG0456 | E | 73.77 | 2
expected_table = data.table.zip expected_column expected_table = data.table.zip expected_column
result.should_equal expected_table result.should_equal expected_table
group_builder.specify "Can provide running sum of integer columns (returning column of floats)" <|
result = data.integer_table.running Statistic.Sum "Ticket Price" "Sum ticket cost"
expected_column = Column.from_vector "Sum ticket cost" [101.0, 677.0, 750.0, 862.0, 936.0]
# | Flight | Passenger | Ticket Price | Sum ticket cost
#---+--------+-----------+--------------+-------------------------
# 0 | BA0123 | A | 101 | 101.0
# 1 | BA0123 | B | 576 | 677.0
# 2 | SG0456 | A | 73 | 750.0
# 3 | BA0123 | C | 112 | 862.0
# 4 | SG0456 | E | 74 | 936.0
expected_table = data.integer_table.zip expected_column
result.should_equal expected_table
suite_builder.group "running mean" group_builder-> suite_builder.group "running mean" group_builder->
data = Data.setup data = Data.setup
group_builder.specify "Not setting the as name gives default name based on of column" <| group_builder.specify "Not setting the as name gives default name based on of column" <|
@ -165,6 +179,18 @@ add_specs suite_builder =
# 4 | SG0456 | E | 73.77 | 187.166 # 4 | SG0456 | E | 73.77 | 187.166
expected_table = data.table.zip expected_column expected_table = data.table.zip expected_column
result.should_equal expected_table result.should_equal expected_table
group_builder.specify "Can provide running mean of integer columns (returning column of floats)" <|
result = data.integer_table.running Statistic.Mean "Ticket Price" "Mean ticket cost"
expected_column = Column.from_vector "Mean ticket cost" [101.0, 338.5, 250, 215.5, 187.2]
# | Flight | Passenger | Ticket Price | Mean ticket cost
#---+--------+-----------+--------------+-------------------------
# 0 | BA0123 | A | 101 | 101.0
# 1 | BA0123 | B | 576 | 338.5
# 2 | SG0456 | A | 73 | 250
# 3 | BA0123 | C | 112 | 215.5
# 4 | SG0456 | E | 74 | 187.2
expected_table = data.integer_table.zip expected_column
result.should_equal expected_table
suite_builder.group "running max" group_builder-> suite_builder.group "running max" group_builder->
data = Data.setup data = Data.setup
group_builder.specify "Not setting the as name gives default name based on of column" <| group_builder.specify "Not setting the as name gives default name based on of column" <|
@ -179,6 +205,32 @@ add_specs suite_builder =
# 4 | SG0456 | E | 73.77 | 575.99 # 4 | SG0456 | E | 73.77 | 575.99
expected_table = data.table.zip expected_column expected_table = data.table.zip expected_column
result.should_equal expected_table result.should_equal expected_table
group_builder.specify "Can provide running max of integer columns (returning column of integers)" <|
result = data.integer_table.running Statistic.Maximum "Ticket Price" "Max ticket cost"
expected_column = Column.from_vector "Max ticket cost" [101, 576, 576, 576, 576]
# | Flight | Passenger | Ticket Price | Max ticket cost
#---+--------+-----------+--------------+-------------------------
# 0 | BA0123 | A | 101 | 101
# 1 | BA0123 | B | 576 | 576
# 2 | SG0456 | A | 73 | 576
# 3 | BA0123 | C | 112 | 576
# 4 | SG0456 | E | 74 | 576
expected_table = data.integer_table.zip expected_column
result.should_equal expected_table
group_builder.specify "Can provide running max of int 16 columns (returning column of int 16)" <|
int16_col = data.integer_table.at "Ticket Price" . cast (Value_Type.Integer Bits.Bits_16)
int16_table = data.integer_table.set int16_col "Ticket Price"
result = int16_table.running Statistic.Maximum "Ticket Price" "Maximum ticket cost"
expected_column = Column.from_vector "Maximum ticket cost" [101, 576, 576, 576, 576] . cast (Value_Type.Integer Bits.Bits_16)
# | Flight | Passenger | Ticket Price | Maximum ticket cost
#---+--------+-----------+--------------+-------------------------
# 0 | BA0123 | A | 101 | 101
# 1 | BA0123 | B | 576 | 576
# 2 | SG0456 | A | 73 | 576
# 3 | BA0123 | C | 112 | 576
# 4 | SG0456 | E | 74 | 576
expected_table = int16_table.zip expected_column
result.should_equal expected_table
suite_builder.group "running min" group_builder-> suite_builder.group "running min" group_builder->
data = Data.setup data = Data.setup
group_builder.specify "Not setting the as name gives default name based on of column" <| group_builder.specify "Not setting the as name gives default name based on of column" <|
@ -193,6 +245,32 @@ add_specs suite_builder =
# 4 | SG0456 | E | 73.77 | 73.23 # 4 | SG0456 | E | 73.77 | 73.23
expected_table = data.table.zip expected_column expected_table = data.table.zip expected_column
result.should_equal expected_table result.should_equal expected_table
group_builder.specify "Can provide running mmin of integer columns (returning column of integers)" <|
result = data.integer_table.running Statistic.Minimum "Ticket Price" "Min ticket cost"
expected_column = Column.from_vector "Min ticket cost" [101, 101, 73, 73, 73]
# | Flight | Passenger | Ticket Price | Min ticket cost
#---+--------+-----------+--------------+-------------------------
# 0 | BA0123 | A | 101 | 101
# 1 | BA0123 | B | 576 | 101
# 2 | SG0456 | A | 73 | 73
# 3 | BA0123 | C | 112 | 73
# 4 | SG0456 | E | 74 | 73
expected_table = data.integer_table.zip expected_column
result.should_equal expected_table
group_builder.specify "Can provide running min of int 16 columns (returning column of int 16)" <|
int16_col = data.integer_table.at "Ticket Price" . cast (Value_Type.Integer Bits.Bits_16)
int16_table = data.integer_table.set int16_col "Ticket Price"
result = int16_table.running Statistic.Minimum "Ticket Price" "Min ticket cost"
expected_column = Column.from_vector "Min ticket cost" [101, 101, 73, 73, 73] . cast (Value_Type.Integer Bits.Bits_16)
# | Flight | Passenger | Ticket Price | Min ticket cost
#---+--------+-----------+--------------+-------------------------
# 0 | BA0123 | A | 101 | 101
# 1 | BA0123 | B | 576 | 101
# 2 | SG0456 | A | 73 | 73
# 3 | BA0123 | C | 112 | 73
# 4 | SG0456 | E | 74 | 73
expected_table = int16_table.zip expected_column
result.should_equal expected_table
suite_builder.group "nothing handling" group_builder-> suite_builder.group "nothing handling" group_builder->
# | Flight | Passenger | Ticket Price # | Flight | Passenger | Ticket Price
#---+--------+-----------+-------------- #---+--------+-----------+--------------