mirror of
https://github.com/enso-org/enso.git
synced 2024-11-23 08:08:34 +03:00
Add Linear Regression support for Vectors. (#3601)
Adds least squares regression APIs. Covers the basic 4 trend line types from Excel (doesn't cover Polynomial or Moving Average). Removes the old `Model` from the `Standard.Table`.
This commit is contained in:
parent
5b4aac0138
commit
be311457bd
@ -166,6 +166,8 @@
|
||||
- [Fixed the case of various type names and library paths][3590]
|
||||
- [Added support for parsing `.pgpass` file and `PG*` environment variables for
|
||||
the Postgres connection][3593]
|
||||
- [Added `Regression` to the `Standard.Base` library and removed legacy `Model`
|
||||
type from `Standard.Table`.][3601]
|
||||
|
||||
[debug-shortcuts]:
|
||||
https://github.com/enso-org/enso/blob/develop/app/gui/docs/product/shortcuts.md#debug
|
||||
@ -263,6 +265,7 @@
|
||||
[3588]: https://github.com/enso-org/enso/pull/3588
|
||||
[3590]: https://github.com/enso-org/enso/pull/3590
|
||||
[3593]: https://github.com/enso-org/enso/pull/3593
|
||||
[3601]: https://github.com/enso-org/enso/pull/3601
|
||||
|
||||
#### Enso Compiler
|
||||
|
||||
|
@ -0,0 +1,108 @@
|
||||
from Standard.Base import all
|
||||
|
||||
polyglot java import org.enso.base.statistics.Regression
|
||||
polyglot java import org.enso.base.statistics.FitError
|
||||
|
||||
type Model
|
||||
## Fit a line (y = A x + B) to the data with an optional fixed intercept.
|
||||
type Linear_Model (intercept:Number|Nothing=Nothing)
|
||||
|
||||
## Fit a exponential line (y = A exp(B x)) to the data with an optional fixed intercept.
|
||||
type Exponential_Model (intercept:Number|Nothing=Nothing)
|
||||
|
||||
## Fit a logarithmic line (y = A log x + B) to the data.
|
||||
type Logarithmic_Model
|
||||
|
||||
## Fit a power series (y = A x ^ B) to the data.
|
||||
type Power_Model
|
||||
|
||||
## Use Least Squares to fit a line to the data.
|
||||
fit_least_squares : Vector -> Vector -> Model -> Fitted_Model ! Illegal_Argument_Error | Fit_Error
|
||||
fit_least_squares known_xs known_ys model=Linear_Model =
|
||||
Illegal_Argument_Error.handle_java_exception <| Fit_Error.handle_java_exception <| case model of
|
||||
Linear_Model intercept ->
|
||||
fitted = if intercept.is_nothing then Regression.fit_linear known_xs.to_array known_ys.to_array else
|
||||
Regression.fit_linear known_xs.to_array known_ys.to_array intercept
|
||||
Fitted_Linear_Model fitted.slope fitted.intercept fitted.rSquared
|
||||
Exponential_Model intercept ->
|
||||
log_ys = ln_series known_ys "Y-values"
|
||||
fitted = if intercept.is_nothing then Regression.fit_linear known_xs.to_array log_ys.to_array else
|
||||
Regression.fit_linear known_xs.to_array log_ys.to_array intercept.ln
|
||||
fitted_model_with_r_squared Fitted_Exponential_Model fitted.intercept.exp fitted.slope known_xs known_ys
|
||||
Logarithmic_Model ->
|
||||
log_xs = ln_series known_xs "X-values"
|
||||
fitted = Regression.fit_linear log_xs.to_array known_ys.to_array
|
||||
fitted_model_with_r_squared Fitted_Logarithmic_Model fitted.slope fitted.intercept known_xs known_ys
|
||||
Power_Model ->
|
||||
log_xs = ln_series known_xs "X-values"
|
||||
log_ys = ln_series known_ys "Y-values"
|
||||
fitted = Regression.fit_linear log_xs.to_array log_ys.to_array
|
||||
fitted_model_with_r_squared Fitted_Power_Model fitted.intercept.exp fitted.slope known_xs known_ys
|
||||
_ -> Error.throw (Illegal_Argument_Error "Unsupported model.")
|
||||
|
||||
type Fitted_Model
|
||||
## Fitted line (y = slope x + intercept).
|
||||
type Fitted_Linear_Model slope:Number intercept:Number r_squared:Number=0.0
|
||||
|
||||
## Fitted exponential line (y = a exp(b x)).
|
||||
type Fitted_Exponential_Model a:Number b:Number r_squared:Number=0.0
|
||||
|
||||
## Fitted logarithmic line (y = a log x + b).
|
||||
type Fitted_Logarithmic_Model a:Number b:Number r_squared:Number=0.0
|
||||
|
||||
## Fitted power series (y = a x ^ b).
|
||||
type Fitted_Power_Model a:Number b:Number r_squared:Number=0.0
|
||||
|
||||
## Display the fitted line.
|
||||
to_text : Text
|
||||
to_text =
|
||||
equation = case self of
|
||||
Fitted_Linear_Model slope intercept _ -> slope.to_text + " * X + " + intercept.to_text
|
||||
Fitted_Exponential_Model a b _ -> a.to_text + " * (" + b.to_text + " * X).exp"
|
||||
Fitted_Logarithmic_Model a b _ -> a.to_text + " * X.ln + " + b.to_text
|
||||
Fitted_Power_Model a b _ -> a.to_text + " * X ^ " + b.to_text
|
||||
"Fitted_Model(" + equation + ")"
|
||||
|
||||
## Use the model to predict a value.
|
||||
predict : Number -> Number
|
||||
predict x = case self of
|
||||
Fitted_Linear_Model slope intercept _ -> slope * x + intercept
|
||||
Fitted_Exponential_Model a b _ -> a * (b * x).exp
|
||||
Fitted_Logarithmic_Model a b _ -> a * x.ln + b
|
||||
Fitted_Power_Model a b _ -> a * (x ^ b)
|
||||
_ -> Error.throw (Illegal_Argument_Error "Unsupported model.")
|
||||
|
||||
## PRIVATE
|
||||
Computes the R Squared value for a model and returns a new instance.
|
||||
fitted_model_with_r_squared : Any -> Number -> Number -> Vector -> Vector -> Fitted_Model
|
||||
fitted_model_with_r_squared constructor a b known_xs known_ys =
|
||||
model = constructor a b
|
||||
r_squared = known_ys.compute (Statistics.R_Squared (known_xs.map model.predict))
|
||||
constructor a b r_squared
|
||||
|
||||
## PRIVATE
|
||||
|
||||
Computes the natural log series as long as all values are positive.
|
||||
ln_series : Vector -> Vector ! Illegal_Argument_Error
|
||||
ln_series xs series_name="Values" =
|
||||
ln_with_panic x = if x.is_nothing then Nothing else
|
||||
if x <= 0 then Panic.throw (Illegal_Argument_Error (series_name + " must be positive.")) else x.ln
|
||||
Panic.recover Illegal_Argument_Error <| xs.map ln_with_panic
|
||||
|
||||
## PRIVATE
|
||||
|
||||
An error thrown when the linear regression cannot be computed.
|
||||
|
||||
Arguments:
|
||||
- message: The error message.
|
||||
type Fit_Error message
|
||||
|
||||
## PRIVATE
|
||||
|
||||
Converts the `Fit_Error` to a human-readable representation.
|
||||
Fit_Error.to_display_text : Text
|
||||
Fit_Error.to_display_text = "Could not fit the model: " + self.message.to_text
|
||||
|
||||
## PRIVATE
|
||||
Fit_Error.handle_java_exception =
|
||||
Panic.catch_java FitError handler=(java_exception-> Error.throw (Fit_Error java_exception.getMessage))
|
@ -11,7 +11,6 @@ polyglot java import org.enso.base.statistics.CountMinMax
|
||||
polyglot java import org.enso.base.statistics.CorrelationStatistics
|
||||
polyglot java import org.enso.base.statistics.Rank
|
||||
|
||||
polyglot java import java.lang.IllegalArgumentException
|
||||
polyglot java import java.lang.ClassCastException
|
||||
polyglot java import java.lang.NullPointerException
|
||||
|
||||
@ -185,10 +184,7 @@ wrap_java_call ~function =
|
||||
report_unsupported _ = Error.throw (Illegal_Argument_Error ("Can only compute correlations on numerical data sets."))
|
||||
handle_unsupported = Panic.catch Unsupported_Argument_Types handler=report_unsupported
|
||||
|
||||
report_illegal caught_panic = Error.throw (Illegal_Argument_Error caught_panic.payload.cause.getMessage)
|
||||
handle_illegal = Panic.catch IllegalArgumentException handler=report_illegal
|
||||
|
||||
handle_unsupported <| handle_illegal <| function
|
||||
handle_unsupported <| Illegal_Argument_Error.handle_java_exception <| function
|
||||
|
||||
|
||||
## PRIVATE
|
||||
|
@ -2,6 +2,8 @@ from Standard.Base import all
|
||||
import Standard.Base.Data.Json
|
||||
import Standard.Base.Runtime
|
||||
|
||||
polyglot java import java.lang.IllegalArgumentException
|
||||
|
||||
## Dataflow errors.
|
||||
type Error
|
||||
|
||||
@ -191,6 +193,11 @@ type Illegal_Argument_Error
|
||||
- cause: (optional) another error that is the cause of this one.
|
||||
type Illegal_Argument_Error message cause=Nothing
|
||||
|
||||
## PRIVATE
|
||||
Capture a Java IllegalArgumentException and rethrow
|
||||
handle_java_exception =
|
||||
Panic.catch_java IllegalArgumentException handler=(cause-> Error.throw (Illegal_Argument_Error cause.getMessage cause))
|
||||
|
||||
## PRIVATE
|
||||
Wraps a dataflow error lifted to a panic, making possible to distinguish it
|
||||
from other panics.
|
||||
|
@ -34,6 +34,9 @@ import project.System.Environment
|
||||
import project.System.File
|
||||
import project.Data.Text.Regex.Mode as Regex_Mode
|
||||
import project.Warning
|
||||
import project.Data.Statistics
|
||||
import project.Data.Statistics.Rank_Method
|
||||
import project.Data.Regression
|
||||
|
||||
export project.Data.Interval
|
||||
export project.Data.Json
|
||||
@ -42,6 +45,9 @@ export project.Data.Map
|
||||
export project.Data.Maybe
|
||||
export project.Data.Ordering
|
||||
export project.Data.Ordering.Sort_Direction
|
||||
export project.Data.Regression
|
||||
export project.Data.Statistics
|
||||
export project.Data.Statistics.Rank_Method
|
||||
export project.Data.Vector
|
||||
export project.IO
|
||||
export project.Math
|
||||
|
@ -114,7 +114,7 @@ type Aggregate_Column
|
||||
- column: column (specified by name, index or Column object) to compute
|
||||
standard deviation.
|
||||
- name: name of new column.
|
||||
- population argument specifies if group is a sample or the population
|
||||
- population: specifies if group is a sample or the population
|
||||
type Standard_Deviation (column:Column|Text|Integer) (new_name:Text|Nothing=Nothing) (population:Boolean=False)
|
||||
|
||||
## Creates a new column with the values concatenated together. `Nothing`
|
||||
|
@ -18,7 +18,6 @@ polyglot java import org.enso.table.error.ExistingDataException
|
||||
polyglot java import org.enso.table.error.RangeExceededException
|
||||
polyglot java import org.enso.table.error.InvalidLocationException
|
||||
|
||||
polyglot java import java.lang.IllegalArgumentException
|
||||
polyglot java import java.lang.IllegalStateException
|
||||
polyglot java import java.io.IOException
|
||||
polyglot java import org.apache.poi.UnsupportedFileFormatException
|
||||
@ -114,8 +113,7 @@ type Excel_Range
|
||||
## Creates a Range from an address.
|
||||
from_address : Text -> Excel_Range
|
||||
from_address address =
|
||||
illegal_argument caught_panic = Error.throw (Illegal_Argument_Error caught_panic.payload.cause.getMessage caught_panic.payload.cause)
|
||||
Panic.catch IllegalArgumentException handler=illegal_argument <|
|
||||
Illegal_Argument_Error.handle_java_exception <|
|
||||
Excel_Range (Java_Range.new address)
|
||||
|
||||
## Create a Range for a single cell.
|
||||
@ -281,15 +279,11 @@ handle_writer ~writer =
|
||||
throw_existing_data caught_panic = Error.throw (Existing_Data caught_panic.payload.cause.getMessage)
|
||||
handle_existing_data = Panic.catch ExistingDataException handler=throw_existing_data
|
||||
|
||||
## Illegal argument can occur if appending in an invalid mode
|
||||
illegal_argument caught_panic = Error.throw (Illegal_Argument_Error caught_panic.payload.cause.getMessage caught_panic.payload.cause)
|
||||
handle_illegal_argument = Panic.catch IllegalArgumentException handler=illegal_argument
|
||||
|
||||
## Should be impossible - occurs if no fallback serializer is provided.
|
||||
throw_illegal_state caught_panic = Panic.throw (Illegal_State_Error caught_panic.payload.cause.getMessage)
|
||||
handle_illegal_state = Panic.catch IllegalStateException handler=throw_illegal_state
|
||||
|
||||
handle_illegal_state <| Column_Name_Mismatch.handle_java_exception <|
|
||||
Column_Count_Mismatch.handle_java_exception <| handle_bad_location <|
|
||||
handle_illegal_argument <| handle_range_exceeded <| handle_existing_data <|
|
||||
Illegal_Argument_Error.handle_java_exception <| handle_range_exceeded <| handle_existing_data <|
|
||||
writer
|
||||
|
@ -17,7 +17,6 @@ polyglot java import org.enso.table.parsing.problems.MismatchedQuote
|
||||
polyglot java import org.enso.table.parsing.problems.AdditionalInvalidRows
|
||||
polyglot java import org.enso.table.util.problems.DuplicateNames
|
||||
polyglot java import org.enso.table.util.problems.InvalidNames
|
||||
polyglot java import java.lang.IllegalArgumentException
|
||||
polyglot java import java.io.IOException
|
||||
polyglot java import com.univocity.parsers.common.TextParsingException
|
||||
polyglot java import org.enso.base.Encoding_Utils
|
||||
@ -91,7 +90,7 @@ read_stream format stream on_problems max_columns=default_max_columns related_fi
|
||||
integer.
|
||||
read_from_reader : Delimited -> Reader -> Problem_Behavior -> Integer -> Any
|
||||
read_from_reader format java_reader on_problems max_columns=4096 =
|
||||
handle_illegal_arguments <| handle_parsing_failure <| handle_parsing_exception <|
|
||||
Illegal_Argument_Error.handle_java_exception <| handle_parsing_failure <| handle_parsing_exception <|
|
||||
reader = prepare_delimited_reader java_reader format max_columns on_problems
|
||||
result_with_problems = reader.read
|
||||
parsing_problems = Vector.Vector (result_with_problems.problems) . map translate_reader_problem
|
||||
@ -160,7 +159,7 @@ type Detected_File_Metadata
|
||||
detect_metadata : File -> File_Format.Delimited -> Detected_Headers
|
||||
detect_metadata file format =
|
||||
on_problems = Ignore
|
||||
result = handle_io_exception file <| handle_illegal_arguments <| handle_parsing_failure <| handle_parsing_exception <|
|
||||
result = handle_io_exception file <| Illegal_Argument_Error.handle_java_exception <| handle_parsing_failure <| handle_parsing_exception <|
|
||||
file.with_input_stream [File.Option.Read] stream->
|
||||
stream.with_stream_decoder format.encoding on_problems java_reader->
|
||||
## We use the default `max_columns` setting. If we want to be able to
|
||||
@ -179,12 +178,6 @@ detect_metadata file format =
|
||||
Detected_File_Metadata headers line_separator
|
||||
result.catch File.File_Not_Found (_->(Detected_File_Metadata Nothing Nothing))
|
||||
|
||||
## PRIVATE
|
||||
handle_illegal_arguments =
|
||||
translate_illegal_argument caught_panic =
|
||||
Error.throw (Illegal_Argument_Error caught_panic.payload.cause.getMessage)
|
||||
Panic.catch IllegalArgumentException handler=translate_illegal_argument
|
||||
|
||||
## PRIVATE
|
||||
handle_parsing_failure =
|
||||
translate_parsing_failure caught_panic =
|
||||
|
@ -6,12 +6,10 @@ import Standard.Table.IO.File_Format
|
||||
import Standard.Table.IO.Excel
|
||||
import Standard.Table.Data.Table
|
||||
import Standard.Table.Data.Column
|
||||
import Standard.Table.Model
|
||||
|
||||
from Standard.Table.IO.Excel export Excel_Section, Excel_Range
|
||||
|
||||
export Standard.Table.Data.Column
|
||||
export Standard.Table.Model
|
||||
export Standard.Table.IO.File_Read
|
||||
export Standard.Table.IO.File_Format
|
||||
|
||||
|
@ -1,77 +0,0 @@
|
||||
from Standard.Base import all
|
||||
|
||||
from Standard.Table import Column, Table
|
||||
|
||||
## Compute the linear regression between the x and y coordinates, returning a
|
||||
vector containing the slope in the first position and the bias in the second.
|
||||
|
||||
Arguments:
|
||||
- x_column: The name of the column in `self` containing the x values.
|
||||
- y_column: The name of the column in `self` containing the y values.
|
||||
|
||||
If the columns don't match in length, it throws a `Fit_Error`.
|
||||
|
||||
> Example
|
||||
Compute the linear regression between two columns in a table.
|
||||
|
||||
from Standard.Table import all
|
||||
|
||||
example_linear_regression =
|
||||
column_x = Column.from_vector "x" [1, 2, 3, 4, 5]
|
||||
column_y = Column.from_vector "y" [2, 4, 6, 8, 10]
|
||||
table = Table.new [column_x, column_y]
|
||||
table.linear_regression "x" "y"
|
||||
Table.Table.linear_regression : Text -> Text -> Vector Number ! Fit_Error
|
||||
Table.Table.linear_regression x_column y_column =
|
||||
x_values = self.at x_column
|
||||
y_values = self.at y_column
|
||||
linear_regression x_values y_values
|
||||
|
||||
## Compute the linear regression between the x and y coordinates, returning a
|
||||
vector containing the slope in the first position and the bias in the second.
|
||||
|
||||
Arguments:
|
||||
- x_values: The column of x coordinate values for each coordinate pair.
|
||||
- y_values: The column of y coordinate values for each coordinate pair.
|
||||
|
||||
If the columns don't match in length, it throws a `Fit_Error`.
|
||||
|
||||
> Example
|
||||
Compute the linear regression between two columns.
|
||||
|
||||
from Standard.Table import all
|
||||
|
||||
example_linear_regression =
|
||||
column_x = Column.from_vector "x" [1, 2, 3, 4, 5]
|
||||
column_y = Column.from_vector "y" [2, 4, 6, 8, 10]
|
||||
Model.linear_regression column_x column_y
|
||||
linear_regression : Column -> Column -> Vector Number ! Fit_Error
|
||||
linear_regression x_values y_values =
|
||||
if x_values.length != y_values.length then Error.throw (Fit_Error "Columns have different lengths.") else
|
||||
n = x_values.length
|
||||
x_squared = x_values.map (^2)
|
||||
x_y = x_values * y_values
|
||||
|
||||
slope_numerator = (n * x_y.sum) - (x_values.sum * y_values.sum)
|
||||
slope_denominator = (n * x_squared.sum) - (x_values.sum ^ 2)
|
||||
slope = slope_numerator / slope_denominator
|
||||
|
||||
bias_numerator = y_values.sum - (slope * x_values.sum)
|
||||
bias = bias_numerator / n
|
||||
|
||||
[slope, bias]
|
||||
|
||||
## PRIVATE
|
||||
|
||||
An error thrown when the linear regression cannot be computed.
|
||||
|
||||
Arguments:
|
||||
- message: The error message.
|
||||
type Fit_Error message
|
||||
|
||||
## PRIVATE
|
||||
|
||||
Converts the `Fit_Error` to a human-readable representation.
|
||||
Fit_Error.to_display_text : Text
|
||||
Fit_Error.to_display_text = "Could not fit the model: " + self.message.to_text
|
||||
|
@ -22,6 +22,51 @@ public class CorrelationStatistics {
|
||||
totalXY += x * y;
|
||||
}
|
||||
|
||||
/*
|
||||
* Count of non-null pairs of values.
|
||||
*/
|
||||
public long getCount() {
|
||||
return count;
|
||||
}
|
||||
|
||||
/*
|
||||
* Sum of X values.
|
||||
*/
|
||||
public double getTotalX() {
|
||||
return totalX;
|
||||
}
|
||||
|
||||
/*
|
||||
* Sum of Y values.
|
||||
*/
|
||||
public double getTotalY() {
|
||||
return totalY;
|
||||
}
|
||||
|
||||
/*
|
||||
* Sum of X^2 values.
|
||||
*/
|
||||
public double getTotalXX() {
|
||||
return totalXX;
|
||||
}
|
||||
|
||||
/*
|
||||
* Sum of X * Y values.
|
||||
*/
|
||||
public double getTotalXY() {
|
||||
return totalXY;
|
||||
}
|
||||
|
||||
/*
|
||||
* Sum of Y^2 values.
|
||||
*/
|
||||
public double getTotalYY() {
|
||||
return totalYY;
|
||||
}
|
||||
|
||||
/*
|
||||
* Compute the covariance of X and Y.
|
||||
*/
|
||||
public double covariance() {
|
||||
if (count < 2) {
|
||||
return Double.NaN;
|
||||
@ -30,6 +75,9 @@ public class CorrelationStatistics {
|
||||
return (totalXY - totalX * totalY / count) / count;
|
||||
}
|
||||
|
||||
/*
|
||||
* Compute the Pearson correlation between X and Y.
|
||||
*/
|
||||
public double pearsonCorrelation() {
|
||||
if (count < 2) {
|
||||
return Double.NaN;
|
||||
@ -40,6 +88,9 @@ public class CorrelationStatistics {
|
||||
return (count * totalXY - totalX * totalY) / (n_stdev_x * n_stdev_y);
|
||||
}
|
||||
|
||||
/*
|
||||
* Compute the R-Squared between X and Y (which equals the Pearson correlation ^ 2).
|
||||
*/
|
||||
public double rSquared() {
|
||||
double correl = this.pearsonCorrelation();
|
||||
return correl * correl;
|
||||
|
@ -0,0 +1,10 @@
|
||||
package org.enso.base.statistics;
|
||||
|
||||
/*
|
||||
A class for exceptions thrown when fitting a model.
|
||||
*/
|
||||
public class FitError extends Exception {
|
||||
public FitError(String message) {
|
||||
super(message);
|
||||
}
|
||||
}
|
@ -0,0 +1,3 @@
|
||||
package org.enso.base.statistics;
|
||||
|
||||
public record LinearModel(double slope, double intercept, double rSquared) {}
|
@ -0,0 +1,58 @@
|
||||
package org.enso.base.statistics;
|
||||
|
||||
public class Regression {
|
||||
/**
|
||||
* Performs a least squares fit of a line to the data.
|
||||
*
|
||||
* @param known_xs Set of known X values.
|
||||
* @param known_ys Set of known Y values.
|
||||
* @return A fitted linear model (y = Intercept + Slope x) and the r-squared value.
|
||||
* @throws IllegalArgumentException if the number of elements in the arrays is different or a
|
||||
* singular X value is provided.
|
||||
*/
|
||||
public static LinearModel fit_linear(Double[] known_xs, Double[] known_ys)
|
||||
throws IllegalArgumentException, FitError {
|
||||
CorrelationStatistics stats = CorrelationStatistics.compute(known_xs, known_ys);
|
||||
|
||||
double denominator = denominator(stats);
|
||||
if (denominator == 0) {
|
||||
throw new FitError("Singular X value.");
|
||||
}
|
||||
|
||||
double slope = slope(stats, denominator);
|
||||
return new LinearModel(slope, intercept(stats, slope), stats.rSquared());
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs a least squares fit of a line to the data with a given intercept.
|
||||
*
|
||||
* @param known_xs Set of known X values.
|
||||
* @param known_ys Set of known Y values.
|
||||
* @param intercept The intercept of the line.
|
||||
* @return A fitted linear model (y = Intercept + Slope x) and the r-squared value.
|
||||
* @throws IllegalArgumentException if the number of elements in the arrays is different or a
|
||||
* singular X value is provided.
|
||||
*/
|
||||
public static LinearModel fit_linear(Double[] known_xs, Double[] known_ys, double intercept)
|
||||
throws IllegalArgumentException {
|
||||
CorrelationStatistics stats = CorrelationStatistics.compute(known_xs, known_ys);
|
||||
return new LinearModel(slopeWithIntercept(stats, intercept), intercept, stats.rSquared());
|
||||
}
|
||||
|
||||
private static double denominator(CorrelationStatistics stats) {
|
||||
return stats.getTotalXX() - stats.getTotalX() * stats.getTotalX() / stats.getCount();
|
||||
}
|
||||
|
||||
private static double slope(CorrelationStatistics stats, double denominator) {
|
||||
return (stats.getTotalXY() - stats.getTotalX() * stats.getTotalY() / stats.getCount())
|
||||
/ denominator;
|
||||
}
|
||||
|
||||
private static double slopeWithIntercept(CorrelationStatistics stats, double intercept) {
|
||||
return (-intercept * stats.getTotalX() + stats.getTotalXY()) / stats.getTotalXX();
|
||||
}
|
||||
|
||||
private static double intercept(CorrelationStatistics stats, double slope) {
|
||||
return (stats.getTotalY() - stats.getTotalX() * slope) / stats.getCount();
|
||||
}
|
||||
}
|
@ -2,7 +2,6 @@ from Standard.Base import all
|
||||
|
||||
import Standard.Test
|
||||
|
||||
import project.Model_Spec
|
||||
import project.Column_Spec
|
||||
import project.Csv_Spec
|
||||
import project.Delimited_Read_Spec
|
||||
@ -21,7 +20,6 @@ in_memory_spec =
|
||||
Excel_Spec.spec
|
||||
Json_Spec.spec
|
||||
Table_Spec.spec
|
||||
Model_Spec.spec
|
||||
Aggregate_Column_Spec.spec
|
||||
Aggregate_Spec.spec
|
||||
|
||||
|
@ -1,29 +0,0 @@
|
||||
from Standard.Base import all
|
||||
from Standard.Table import all
|
||||
|
||||
from Standard.Table.Model import Fit_Error
|
||||
|
||||
import Standard.Test
|
||||
|
||||
spec =
|
||||
Test.group "Linear regression" <|
|
||||
column_x = Column.from_vector "x" [2, 3, 5, 7, 9]
|
||||
column_y = Column.from_vector "y" [4, 5, 7, 10, 15]
|
||||
column_y_2 = Column.from_vector "y" [4, 5, 7, 10]
|
||||
|
||||
table = Table.new [column_x, column_y]
|
||||
|
||||
Test.specify "return an error if the column lengths do not match" <|
|
||||
result = Model.linear_regression column_x column_y_2
|
||||
result . should_fail_with Fit_Error
|
||||
|
||||
Test.specify "compute the linear least squares" <|
|
||||
result = Model.linear_regression column_x column_y
|
||||
result.length . should_equal 2
|
||||
result.at 0 . should_equal epsilon=0.001 1.518
|
||||
result.at 1 . should_equal epsilon=0.001 0.304
|
||||
|
||||
Test.specify "compute based on columns in a table" <|
|
||||
result = table.linear_regression "x" "y"
|
||||
result . should_equal (Model.linear_regression column_x column_y)
|
||||
|
108
test/Tests/src/Data/Regression_Spec.enso
Normal file
108
test/Tests/src/Data/Regression_Spec.enso
Normal file
@ -0,0 +1,108 @@
|
||||
from Standard.Base import Nothing, Vector, Number, Decimal, True, Illegal_Argument_Error, False, Regression
|
||||
|
||||
import Standard.Test
|
||||
|
||||
spec =
|
||||
## Regression test data produced using an Excel spreadsheet.
|
||||
https://github.com/enso-org/enso/files/9160145/Regression.tests.xlsx
|
||||
|
||||
double_error = 0.000001
|
||||
|
||||
vector_compare values expected =
|
||||
values.zip expected v->e->
|
||||
case v of
|
||||
Decimal -> v.should_equal e epsilon=double_error
|
||||
_ -> v.should_equal e
|
||||
|
||||
Test.group "Regression" <|
|
||||
Test.specify "return an error if the vector lengths do not match" <|
|
||||
known_xs = [2, 3, 5, 7, 9]
|
||||
known_ys = [4, 5, 7, 10]
|
||||
Regression.fit_least_squares known_xs known_ys . should_fail_with Illegal_Argument_Error
|
||||
|
||||
Test.specify "return an error if the X values are all the same" <|
|
||||
known_xs = [2, 2, 2, 2]
|
||||
known_ys = [4, 5, 7, 10]
|
||||
Regression.fit_least_squares known_xs known_ys . should_fail_with Regression.Fit_Error
|
||||
|
||||
Test.specify "compute the linear trend line" <|
|
||||
known_xs = [2, 3, 5, 7, 9]
|
||||
known_ys = [4, 5, 7, 10, 15]
|
||||
fitted = Regression.fit_least_squares known_xs known_ys
|
||||
fitted.slope . should_equal 1.518292683 epsilon=double_error
|
||||
fitted.intercept . should_equal 0.304878049 epsilon=double_error
|
||||
fitted.r_squared . should_equal 0.959530147 epsilon=double_error
|
||||
|
||||
Test.specify "predict values on a linear trend line" <|
|
||||
known_xs = [2, 3, 5, 7, 9]
|
||||
known_ys = [4, 5, 7, 10, 15]
|
||||
fitted = Regression.fit_least_squares known_xs known_ys
|
||||
test_xs = [1, 4, 6, 8, 10]
|
||||
expected_ys = [1.823171, 6.378049, 9.414634, 12.45122, 15.487805]
|
||||
vector_compare (test_xs.map fitted.predict) expected_ys
|
||||
|
||||
Test.specify "compute the linear trend line with an intercept" <|
|
||||
known_xs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
known_ys = [8.02128, 11.02421, 13.99566, 17.02678, 20.00486, 22.95283, 26.0143, 29.03238, 31.96427, 35.03896]
|
||||
fitted = Regression.fit_least_squares known_xs known_ys (Regression.Linear_Model 100)
|
||||
fitted.slope . should_equal -10.57056558 epsilon=double_error
|
||||
fitted.intercept . should_equal 100.0 epsilon=double_error
|
||||
fitted.r_squared . should_equal 0.9999900045 epsilon=double_error
|
||||
|
||||
Test.specify "compute the exponential trend line" <|
|
||||
known_xs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
known_ys = [0.28652, 0.31735, 0.31963, 0.38482, 0.40056, 0.39013, 0.4976, 0.5665, 0.55457, 0.69135]
|
||||
fitted = Regression.fit_least_squares known_xs known_ys Regression.Exponential_Model
|
||||
fitted.a . should_equal 0.25356436 epsilon=double_error
|
||||
fitted.b . should_equal 0.09358242 epsilon=double_error
|
||||
fitted.r_squared . should_equal 0.9506293649 epsilon=double_error
|
||||
|
||||
Test.specify "predict values on a exponential trend line" <|
|
||||
known_xs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
known_ys = [0.28652, 0.31735, 0.31963, 0.38482, 0.40056, 0.39013, 0.4976, 0.5665, 0.55457, 0.69135]
|
||||
fitted = Regression.fit_least_squares known_xs known_ys Regression.Exponential_Model
|
||||
test_xs = [0, 11, 12, 15]
|
||||
expected_ys = [0.253564, 0.709829, 0.779464, 1.032103]
|
||||
vector_compare (test_xs.map fitted.predict) expected_ys
|
||||
|
||||
Test.specify "compute the exponential trend line with an intercept" <|
|
||||
known_xs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
known_ys = [0.28652, 0.31735, 0.31963, 0.38482, 0.40056, 0.39013, 0.4976, 0.5665, 0.55457, 0.69135]
|
||||
fitted = Regression.fit_least_squares known_xs known_ys (Regression.Exponential_Model 0.2)
|
||||
fitted.a . should_equal 0.2 epsilon=double_error
|
||||
fitted.b . should_equal 0.127482464 epsilon=double_error
|
||||
fitted.r_squared . should_equal 0.9566066546 epsilon=double_error
|
||||
|
||||
Test.specify "compute the logarithmic trend line" <|
|
||||
known_xs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
known_ys = [0.12128, 0.29057, 0.35933, 0.45949, 0.49113, 0.48285, 0.58132, 0.63144, 0.5916, 0.69158]
|
||||
fitted = Regression.fit_least_squares known_xs known_ys Regression.Logarithmic_Model
|
||||
fitted.a . should_equal 0.232702284 epsilon=double_error
|
||||
fitted.b . should_equal 0.11857587 epsilon=double_error
|
||||
fitted.r_squared . should_equal 0.9730840179 epsilon=double_error
|
||||
|
||||
Test.specify "predict values on a logarithmic trend line" <|
|
||||
known_xs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
known_ys = [0.12128, 0.29057, 0.35933, 0.45949, 0.49113, 0.48285, 0.58132, 0.63144, 0.5916, 0.69158]
|
||||
fitted = Regression.fit_least_squares known_xs known_ys Regression.Logarithmic_Model
|
||||
test_xs = [0.1, 11, 12, 15]
|
||||
expected_ys = [-0.417241, 0.676572, 0.696819, 0.748745]
|
||||
vector_compare (test_xs.map fitted.predict) expected_ys
|
||||
|
||||
Test.specify "compute the power trend line" <|
|
||||
known_xs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
known_ys = [0.26128, 0.28144, 0.26353, 0.30247, 0.28677, 0.23992, 0.30586, 0.32785, 0.26324, 0.3411]
|
||||
fitted = Regression.fit_least_squares known_xs known_ys Regression.Power_Model
|
||||
fitted.a . should_equal 0.258838019 epsilon=double_error
|
||||
fitted.b . should_equal 0.065513849 epsilon=double_error
|
||||
fitted.r_squared . should_equal 0.2099579581 epsilon=double_error
|
||||
|
||||
Test.specify "predict values on a power trend line" <|
|
||||
known_xs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
known_ys = [0.26128, 0.28144, 0.26353, 0.30247, 0.28677, 0.23992, 0.30586, 0.32785, 0.26324, 0.3411]
|
||||
fitted = Regression.fit_least_squares known_xs known_ys Regression.Power_Model
|
||||
test_xs = [0.1, 11, 12, 15]
|
||||
expected_ys = [0.222594, 0.302868, 0.3046, 0.309085]
|
||||
vector_compare (test_xs.map fitted.predict) expected_ys
|
||||
|
||||
main = Test.Suite.run_main spec
|
@ -40,6 +40,7 @@ import project.Data.Text_Spec
|
||||
import project.Data.Time.Spec as Time_Spec
|
||||
import project.Data.Vector_Spec
|
||||
import project.Data.Statistics_Spec
|
||||
import project.Data.Regression_Spec
|
||||
import project.Data.Text.Regex_Spec
|
||||
import project.Data.Text.Utils_Spec
|
||||
import project.Data.Text.Default_Regex_Engine_Spec
|
||||
@ -118,5 +119,6 @@ main = Test.Suite.run_main <|
|
||||
Uri_Spec.spec
|
||||
Vector_Spec.spec
|
||||
Statistics_Spec.spec
|
||||
Regression_Spec.spec
|
||||
Warnings_Spec.spec
|
||||
System_Spec.spec
|
||||
|
Loading…
Reference in New Issue
Block a user