mirror of
https://github.com/enso-org/enso.git
synced 2024-11-22 22:10:15 +03:00
Make table.Running return integer typed columns for min/max (#9853)
* New Tests * Green * Running min for longs * Unsupported types test * Revert * Add support for all the integer types * Another test
This commit is contained in:
parent
930f3c593e
commit
15976a8505
@ -5,6 +5,8 @@ import org.enso.base.polyglot.NumericConverter;
|
||||
import org.enso.base.statistics.Statistic;
|
||||
import org.enso.table.data.column.storage.Storage;
|
||||
import org.enso.table.data.column.storage.numeric.DoubleStorage;
|
||||
import org.enso.table.data.column.storage.numeric.LongStorage;
|
||||
import org.enso.table.data.column.storage.type.IntegerType;
|
||||
import org.enso.table.data.table.Column;
|
||||
import org.enso.table.data.table.problems.IgnoredNaN;
|
||||
import org.enso.table.data.table.problems.IgnoredNothing;
|
||||
@ -31,7 +33,7 @@ public class AddRunning {
|
||||
return runningStatistic.getResult();
|
||||
}
|
||||
|
||||
private static RunningStatistic<Double> createRunningStatistic(
|
||||
private static RunningStatistic<?> createRunningStatistic(
|
||||
Statistic statistic, Column sourceColumn, ProblemAggregator problemAggregator) {
|
||||
switch (statistic) {
|
||||
case Sum -> {
|
||||
@ -41,38 +43,98 @@ public class AddRunning {
|
||||
return new RunningMeanStatistic(sourceColumn, problemAggregator);
|
||||
}
|
||||
case Minimum -> {
|
||||
if (sourceColumn.getStorage().getType() instanceof IntegerType type) {
|
||||
return new RunningMinLongStatistic(sourceColumn, problemAggregator, type);
|
||||
}
|
||||
return new RunningMinStatistic(sourceColumn, problemAggregator);
|
||||
}
|
||||
case Maximum -> {
|
||||
if (sourceColumn.getStorage().getType() instanceof IntegerType type) {
|
||||
return new RunningMaxLongStatistic(sourceColumn, problemAggregator, type);
|
||||
}
|
||||
return new RunningMaxStatistic(sourceColumn, problemAggregator);
|
||||
}
|
||||
default -> throw new IllegalArgumentException("Unsupported statistic: " + statistic);
|
||||
}
|
||||
}
|
||||
|
||||
private abstract static class RunningStatisticBase implements RunningStatistic<Double> {
|
||||
private interface TypeHandler<T> {
|
||||
|
||||
T tryConvertingToType(Object o);
|
||||
|
||||
long typeToRawLongBits(T t);
|
||||
|
||||
Storage<T> createStorage(long[] result, int size, BitSet isNothing);
|
||||
}
|
||||
|
||||
private static class DoubleHandler implements TypeHandler<Double> {
|
||||
|
||||
@Override
|
||||
public Double tryConvertingToType(Object o) {
|
||||
return NumericConverter.tryConvertingToDouble(o);
|
||||
}
|
||||
|
||||
@Override
|
||||
public long typeToRawLongBits(Double d) {
|
||||
return Double.doubleToRawLongBits(d);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Storage<Double> createStorage(long[] result, int size, BitSet isNothing) {
|
||||
return new DoubleStorage(result, size, isNothing);
|
||||
}
|
||||
}
|
||||
|
||||
private static class LongHandler implements TypeHandler<Long> {
|
||||
|
||||
IntegerType type;
|
||||
|
||||
LongHandler(IntegerType type) {
|
||||
this.type = type;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Long tryConvertingToType(Object o) {
|
||||
return NumericConverter.tryConvertingToLong(o);
|
||||
}
|
||||
|
||||
@Override
|
||||
public long typeToRawLongBits(Long l) {
|
||||
return l;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Storage<Long> createStorage(long[] result, int size, BitSet isNothing) {
|
||||
return new LongStorage(result, size, isNothing, type);
|
||||
}
|
||||
}
|
||||
|
||||
private abstract static class RunningStatisticBase<T> implements RunningStatistic<T> {
|
||||
|
||||
long[] result;
|
||||
BitSet isNothing;
|
||||
ColumnAggregatedProblemAggregator columnAggregatedProblemAggregator;
|
||||
Column sourceColumn;
|
||||
TypeHandler<T> typeHandler;
|
||||
|
||||
RunningStatisticBase(Column sourceColumn, ProblemAggregator problemAggregator) {
|
||||
RunningStatisticBase(
|
||||
Column sourceColumn, ProblemAggregator problemAggregator, TypeHandler<T> typeHandler) {
|
||||
result = new long[sourceColumn.getSize()];
|
||||
isNothing = new BitSet();
|
||||
columnAggregatedProblemAggregator = new ColumnAggregatedProblemAggregator(problemAggregator);
|
||||
this.sourceColumn = sourceColumn;
|
||||
this.typeHandler = typeHandler;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void calculateNextValue(int i, RunningIterator<Double> it) {
|
||||
public void calculateNextValue(int i, RunningIterator<T> it) {
|
||||
Object value = sourceColumn.getStorage().getItemBoxed(i);
|
||||
if (value == null) {
|
||||
columnAggregatedProblemAggregator.reportColumnAggregatedProblem(
|
||||
new IgnoredNothing(sourceColumn.getName(), i));
|
||||
}
|
||||
Double dValue = NumericConverter.tryConvertingToDouble(value);
|
||||
Double dNextValue;
|
||||
T dValue = typeHandler.tryConvertingToType(value);
|
||||
T dNextValue;
|
||||
if (dValue != null && dValue.equals(Double.NaN)) {
|
||||
columnAggregatedProblemAggregator.reportColumnAggregatedProblem(
|
||||
new IgnoredNaN(sourceColumn.getName(), i));
|
||||
@ -83,13 +145,13 @@ public class AddRunning {
|
||||
if (dNextValue == null) {
|
||||
isNothing.set(i);
|
||||
} else {
|
||||
result[i] = Double.doubleToRawLongBits(dNextValue);
|
||||
result[i] = typeHandler.typeToRawLongBits(dNextValue);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Storage<Double> getResult() {
|
||||
return new DoubleStorage(result, sourceColumn.getSize(), isNothing);
|
||||
public Storage<T> getResult() {
|
||||
return typeHandler.createStorage(result, sourceColumn.getSize(), isNothing);
|
||||
}
|
||||
}
|
||||
|
||||
@ -127,10 +189,10 @@ public class AddRunning {
|
||||
}
|
||||
}
|
||||
|
||||
private static class RunningSumStatistic extends RunningStatisticBase {
|
||||
private static class RunningSumStatistic extends RunningStatisticBase<Double> {
|
||||
|
||||
RunningSumStatistic(Column sourceColumn, ProblemAggregator problemAggregator) {
|
||||
super(sourceColumn, problemAggregator);
|
||||
super(sourceColumn, problemAggregator, new DoubleHandler());
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -147,10 +209,10 @@ public class AddRunning {
|
||||
}
|
||||
}
|
||||
|
||||
private static class RunningMeanStatistic extends RunningStatisticBase {
|
||||
private static class RunningMeanStatistic extends RunningStatisticBase<Double> {
|
||||
|
||||
RunningMeanStatistic(Column sourceColumn, ProblemAggregator problemAggregator) {
|
||||
super(sourceColumn, problemAggregator);
|
||||
super(sourceColumn, problemAggregator, new DoubleHandler());
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -181,10 +243,10 @@ public class AddRunning {
|
||||
}
|
||||
}
|
||||
|
||||
private static class RunningMinStatistic extends RunningStatisticBase {
|
||||
private static class RunningMinStatistic extends RunningStatisticBase<Double> {
|
||||
|
||||
RunningMinStatistic(Column sourceColumn, ProblemAggregator problemAggregator) {
|
||||
super(sourceColumn, problemAggregator);
|
||||
super(sourceColumn, problemAggregator, new DoubleHandler());
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -201,10 +263,31 @@ public class AddRunning {
|
||||
}
|
||||
}
|
||||
|
||||
private static class RunningMaxStatistic extends RunningStatisticBase {
|
||||
private static class RunningMinLongStatistic extends RunningStatisticBase<Long> {
|
||||
|
||||
RunningMinLongStatistic(
|
||||
Column sourceColumn, ProblemAggregator problemAggregator, IntegerType type) {
|
||||
super(sourceColumn, problemAggregator, new LongHandler(type));
|
||||
}
|
||||
|
||||
@Override
|
||||
public RunningIterator<Long> getNewIterator() {
|
||||
return new RunningMinLongIterator();
|
||||
}
|
||||
|
||||
private static class RunningMinLongIterator extends RunningIteratorLong {
|
||||
|
||||
@Override
|
||||
public void increment(long value) {
|
||||
current = Math.min(current, value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static class RunningMaxStatistic extends RunningStatisticBase<Double> {
|
||||
|
||||
RunningMaxStatistic(Column sourceColumn, ProblemAggregator problemAggregator) {
|
||||
super(sourceColumn, problemAggregator);
|
||||
super(sourceColumn, problemAggregator, new DoubleHandler());
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -220,4 +303,59 @@ public class AddRunning {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static class RunningMaxLongStatistic extends RunningStatisticBase<Long> {
|
||||
|
||||
RunningMaxLongStatistic(
|
||||
Column sourceColumn, ProblemAggregator problemAggregator, IntegerType type) {
|
||||
super(sourceColumn, problemAggregator, new LongHandler(type));
|
||||
}
|
||||
|
||||
@Override
|
||||
public RunningIterator<Long> getNewIterator() {
|
||||
return new RunningMaxLongIterator();
|
||||
}
|
||||
|
||||
private static class RunningMaxLongIterator extends RunningIteratorLong {
|
||||
|
||||
@Override
|
||||
public void increment(long value) {
|
||||
current = Math.max(current, value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private abstract static class RunningIteratorLong implements RunningIterator<Long> {
|
||||
|
||||
protected long current;
|
||||
private boolean isInitialized = false;
|
||||
|
||||
@Override
|
||||
public Long next(Long value) {
|
||||
if (value != null) {
|
||||
if (!isInitialized) {
|
||||
isInitialized = true;
|
||||
initialize(value);
|
||||
} else {
|
||||
increment(value);
|
||||
}
|
||||
}
|
||||
return isInitialized ? getCurrent() : null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Long currentValue() {
|
||||
return isInitialized ? getCurrent() : null;
|
||||
}
|
||||
|
||||
protected void initialize(long value) {
|
||||
current = value;
|
||||
}
|
||||
|
||||
protected abstract void increment(long value);
|
||||
|
||||
protected long getCurrent() {
|
||||
return current;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
from Standard.Base import all
|
||||
from Standard.Table import Column, Table
|
||||
from Standard.Table import Column, Table, Bits, Value_Type
|
||||
from Standard.Test import all
|
||||
from Standard.Table.Errors import all
|
||||
import Standard.Base.Errors.Common.Type_Error
|
||||
@ -15,16 +15,18 @@ type Data
|
||||
# 2 | SG0456 | A | 73.23
|
||||
# 3 | BA0123 | C | 112.34
|
||||
# 4 | SG0456 | E | 73.77
|
||||
Value ~table
|
||||
Value ~table ~integer_table
|
||||
|
||||
setup =
|
||||
flight = ["Flight", ["BA0123", "BA0123", "SG0456", "BA0123", "SG0456"]]
|
||||
passenger = ["Passenger", ["A", "B", "A", "C", "E"]]
|
||||
make_table =
|
||||
flight = ["Flight", ["BA0123", "BA0123", "SG0456", "BA0123", "SG0456"]]
|
||||
passenger = ["Passenger", ["A", "B", "A", "C", "E"]]
|
||||
ticket_price = ["Ticket Price", [100.50, 575.99, 73.23, 112.34, 73.77]]
|
||||
|
||||
Table.new [flight, passenger, ticket_price]
|
||||
Data.Value make_table
|
||||
make_integer_table =
|
||||
ticket_price = ["Ticket Price", [101, 576, 73, 112, 74]]
|
||||
Table.new [flight, passenger, ticket_price]
|
||||
Data.Value make_table make_integer_table
|
||||
|
||||
add_specs suite_builder =
|
||||
suite_builder.group "running count" group_builder->
|
||||
@ -142,7 +144,7 @@ add_specs suite_builder =
|
||||
group_builder.specify "Can provide running sum based on order by without grouping" <|
|
||||
result = data.table.running Statistic.Sum "Ticket Price" "Sum ticket cost" [] ["Ticket Price"]
|
||||
expected_column = Column.from_vector "Sum ticket cost" [247.5, 935.83, 73.23, 359.84000000000003, 147]
|
||||
# | Flight | Passenger | Ticket Price | Ranked ticket cost
|
||||
# | Flight | Passenger | Ticket Price | Sum ticket cost
|
||||
#---+--------+-----------+--------------+-------------------------
|
||||
# 0 | BA0123 | A | 100.5 | 3
|
||||
# 1 | BA0123 | B | 575.99 | 5
|
||||
@ -151,6 +153,18 @@ add_specs suite_builder =
|
||||
# 4 | SG0456 | E | 73.77 | 2
|
||||
expected_table = data.table.zip expected_column
|
||||
result.should_equal expected_table
|
||||
group_builder.specify "Can provide running sum of integer columns (returning column of floats)" <|
|
||||
result = data.integer_table.running Statistic.Sum "Ticket Price" "Sum ticket cost"
|
||||
expected_column = Column.from_vector "Sum ticket cost" [101.0, 677.0, 750.0, 862.0, 936.0]
|
||||
# | Flight | Passenger | Ticket Price | Sum ticket cost
|
||||
#---+--------+-----------+--------------+-------------------------
|
||||
# 0 | BA0123 | A | 101 | 101.0
|
||||
# 1 | BA0123 | B | 576 | 677.0
|
||||
# 2 | SG0456 | A | 73 | 750.0
|
||||
# 3 | BA0123 | C | 112 | 862.0
|
||||
# 4 | SG0456 | E | 74 | 936.0
|
||||
expected_table = data.integer_table.zip expected_column
|
||||
result.should_equal expected_table
|
||||
suite_builder.group "running mean" group_builder->
|
||||
data = Data.setup
|
||||
group_builder.specify "Not setting the as name gives default name based on of column" <|
|
||||
@ -165,6 +179,18 @@ add_specs suite_builder =
|
||||
# 4 | SG0456 | E | 73.77 | 187.166
|
||||
expected_table = data.table.zip expected_column
|
||||
result.should_equal expected_table
|
||||
group_builder.specify "Can provide running mean of integer columns (returning column of floats)" <|
|
||||
result = data.integer_table.running Statistic.Mean "Ticket Price" "Mean ticket cost"
|
||||
expected_column = Column.from_vector "Mean ticket cost" [101.0, 338.5, 250, 215.5, 187.2]
|
||||
# | Flight | Passenger | Ticket Price | Mean ticket cost
|
||||
#---+--------+-----------+--------------+-------------------------
|
||||
# 0 | BA0123 | A | 101 | 101.0
|
||||
# 1 | BA0123 | B | 576 | 338.5
|
||||
# 2 | SG0456 | A | 73 | 250
|
||||
# 3 | BA0123 | C | 112 | 215.5
|
||||
# 4 | SG0456 | E | 74 | 187.2
|
||||
expected_table = data.integer_table.zip expected_column
|
||||
result.should_equal expected_table
|
||||
suite_builder.group "running max" group_builder->
|
||||
data = Data.setup
|
||||
group_builder.specify "Not setting the as name gives default name based on of column" <|
|
||||
@ -179,6 +205,32 @@ add_specs suite_builder =
|
||||
# 4 | SG0456 | E | 73.77 | 575.99
|
||||
expected_table = data.table.zip expected_column
|
||||
result.should_equal expected_table
|
||||
group_builder.specify "Can provide running max of integer columns (returning column of integers)" <|
|
||||
result = data.integer_table.running Statistic.Maximum "Ticket Price" "Max ticket cost"
|
||||
expected_column = Column.from_vector "Max ticket cost" [101, 576, 576, 576, 576]
|
||||
# | Flight | Passenger | Ticket Price | Max ticket cost
|
||||
#---+--------+-----------+--------------+-------------------------
|
||||
# 0 | BA0123 | A | 101 | 101
|
||||
# 1 | BA0123 | B | 576 | 576
|
||||
# 2 | SG0456 | A | 73 | 576
|
||||
# 3 | BA0123 | C | 112 | 576
|
||||
# 4 | SG0456 | E | 74 | 576
|
||||
expected_table = data.integer_table.zip expected_column
|
||||
result.should_equal expected_table
|
||||
group_builder.specify "Can provide running max of int 16 columns (returning column of int 16)" <|
|
||||
int16_col = data.integer_table.at "Ticket Price" . cast (Value_Type.Integer Bits.Bits_16)
|
||||
int16_table = data.integer_table.set int16_col "Ticket Price"
|
||||
result = int16_table.running Statistic.Maximum "Ticket Price" "Maximum ticket cost"
|
||||
expected_column = Column.from_vector "Maximum ticket cost" [101, 576, 576, 576, 576] . cast (Value_Type.Integer Bits.Bits_16)
|
||||
# | Flight | Passenger | Ticket Price | Maximum ticket cost
|
||||
#---+--------+-----------+--------------+-------------------------
|
||||
# 0 | BA0123 | A | 101 | 101
|
||||
# 1 | BA0123 | B | 576 | 576
|
||||
# 2 | SG0456 | A | 73 | 576
|
||||
# 3 | BA0123 | C | 112 | 576
|
||||
# 4 | SG0456 | E | 74 | 576
|
||||
expected_table = int16_table.zip expected_column
|
||||
result.should_equal expected_table
|
||||
suite_builder.group "running min" group_builder->
|
||||
data = Data.setup
|
||||
group_builder.specify "Not setting the as name gives default name based on of column" <|
|
||||
@ -193,6 +245,32 @@ add_specs suite_builder =
|
||||
# 4 | SG0456 | E | 73.77 | 73.23
|
||||
expected_table = data.table.zip expected_column
|
||||
result.should_equal expected_table
|
||||
group_builder.specify "Can provide running mmin of integer columns (returning column of integers)" <|
|
||||
result = data.integer_table.running Statistic.Minimum "Ticket Price" "Min ticket cost"
|
||||
expected_column = Column.from_vector "Min ticket cost" [101, 101, 73, 73, 73]
|
||||
# | Flight | Passenger | Ticket Price | Min ticket cost
|
||||
#---+--------+-----------+--------------+-------------------------
|
||||
# 0 | BA0123 | A | 101 | 101
|
||||
# 1 | BA0123 | B | 576 | 101
|
||||
# 2 | SG0456 | A | 73 | 73
|
||||
# 3 | BA0123 | C | 112 | 73
|
||||
# 4 | SG0456 | E | 74 | 73
|
||||
expected_table = data.integer_table.zip expected_column
|
||||
result.should_equal expected_table
|
||||
group_builder.specify "Can provide running min of int 16 columns (returning column of int 16)" <|
|
||||
int16_col = data.integer_table.at "Ticket Price" . cast (Value_Type.Integer Bits.Bits_16)
|
||||
int16_table = data.integer_table.set int16_col "Ticket Price"
|
||||
result = int16_table.running Statistic.Minimum "Ticket Price" "Min ticket cost"
|
||||
expected_column = Column.from_vector "Min ticket cost" [101, 101, 73, 73, 73] . cast (Value_Type.Integer Bits.Bits_16)
|
||||
# | Flight | Passenger | Ticket Price | Min ticket cost
|
||||
#---+--------+-----------+--------------+-------------------------
|
||||
# 0 | BA0123 | A | 101 | 101
|
||||
# 1 | BA0123 | B | 576 | 101
|
||||
# 2 | SG0456 | A | 73 | 73
|
||||
# 3 | BA0123 | C | 112 | 73
|
||||
# 4 | SG0456 | E | 74 | 73
|
||||
expected_table = int16_table.zip expected_column
|
||||
result.should_equal expected_table
|
||||
suite_builder.group "nothing handling" group_builder->
|
||||
# | Flight | Passenger | Ticket Price
|
||||
#---+--------+-----------+--------------
|
||||
|
Loading…
Reference in New Issue
Block a user