From be311457bdca60ef3420c5e8f16bc8c19fd82c17 Mon Sep 17 00:00:00 2001 From: James Dunkerley Date: Fri, 22 Jul 2022 09:41:17 +0100 Subject: [PATCH] Add Linear Regression support for Vectors. (#3601) Adds least squares regression APIs. Covers the basic 4 trend line types from Excel (doesn't cover Polynomial or Moving Average). Removes the old `Model` from the `Standard.Table`. --- CHANGELOG.md | 3 + .../Base/0.0.0-dev/src/Data/Regression.enso | 108 ++++++++++++++++++ .../Base/0.0.0-dev/src/Data/Statistics.enso | 6 +- .../Base/0.0.0-dev/src/Error/Common.enso | 7 ++ .../lib/Standard/Base/0.0.0-dev/src/Main.enso | 6 + .../0.0.0-dev/src/Data/Aggregate_Column.enso | 2 +- .../Table/0.0.0-dev/src/IO/Excel.enso | 10 +- .../src/Internal/Delimited_Reader.enso | 11 +- .../Standard/Table/0.0.0-dev/src/Main.enso | 2 - .../Standard/Table/0.0.0-dev/src/Model.enso | 77 ------------- .../statistics/CorrelationStatistics.java | 51 +++++++++ .../org/enso/base/statistics/FitError.java | 10 ++ .../org/enso/base/statistics/LinearModel.java | 3 + .../org/enso/base/statistics/Regression.java | 58 ++++++++++ test/Table_Tests/src/In_Memory_Tests.enso | 2 - test/Table_Tests/src/Model_Spec.enso | 29 ----- test/Tests/src/Data/Regression_Spec.enso | 108 ++++++++++++++++++ test/Tests/src/Main.enso | 2 + 18 files changed, 362 insertions(+), 133 deletions(-) create mode 100644 distribution/lib/Standard/Base/0.0.0-dev/src/Data/Regression.enso delete mode 100644 distribution/lib/Standard/Table/0.0.0-dev/src/Model.enso create mode 100644 std-bits/base/src/main/java/org/enso/base/statistics/FitError.java create mode 100644 std-bits/base/src/main/java/org/enso/base/statistics/LinearModel.java create mode 100644 std-bits/base/src/main/java/org/enso/base/statistics/Regression.java delete mode 100644 test/Table_Tests/src/Model_Spec.enso create mode 100644 test/Tests/src/Data/Regression_Spec.enso diff --git a/CHANGELOG.md b/CHANGELOG.md index 4bbf16a732..428f686675 100644 --- a/CHANGELOG.md +++ 
b/CHANGELOG.md @@ -166,6 +166,8 @@ - [Fixed the case of various type names and library paths][3590] - [Added support for parsing `.pgpass` file and `PG*` environment variables for the Postgres connection][3593] +- [Added `Regression` to the `Standard.Base` library and removed legacy `Model` + type from `Standard.Table`.][3601] [debug-shortcuts]: https://github.com/enso-org/enso/blob/develop/app/gui/docs/product/shortcuts.md#debug @@ -263,6 +265,7 @@ [3588]: https://github.com/enso-org/enso/pull/3588 [3590]: https://github.com/enso-org/enso/pull/3590 [3593]: https://github.com/enso-org/enso/pull/3593 +[3601]: https://github.com/enso-org/enso/pull/3601 #### Enso Compiler diff --git a/distribution/lib/Standard/Base/0.0.0-dev/src/Data/Regression.enso b/distribution/lib/Standard/Base/0.0.0-dev/src/Data/Regression.enso new file mode 100644 index 0000000000..f11d30ede7 --- /dev/null +++ b/distribution/lib/Standard/Base/0.0.0-dev/src/Data/Regression.enso @@ -0,0 +1,108 @@ +from Standard.Base import all + +polyglot java import org.enso.base.statistics.Regression +polyglot java import org.enso.base.statistics.FitError + +type Model + ## Fit a line (y = A x + B) to the data with an optional fixed intercept. + type Linear_Model (intercept:Number|Nothing=Nothing) + + ## Fit an exponential line (y = A exp(B x)) to the data with an optional fixed intercept. + type Exponential_Model (intercept:Number|Nothing=Nothing) + + ## Fit a logarithmic line (y = A ln x + B) to the data. + type Logarithmic_Model + + ## Fit a power series (y = A x ^ B) to the data. + type Power_Model + +## Use Least Squares to fit the given model (a line by default) to the data. +fit_least_squares : Vector -> Vector -> Model -> Fitted_Model ! 
Illegal_Argument_Error | Fit_Error +fit_least_squares known_xs known_ys model=Linear_Model = + Illegal_Argument_Error.handle_java_exception <| Fit_Error.handle_java_exception <| case model of + Linear_Model intercept -> + fitted = if intercept.is_nothing then Regression.fit_linear known_xs.to_array known_ys.to_array else + Regression.fit_linear known_xs.to_array known_ys.to_array intercept + Fitted_Linear_Model fitted.slope fitted.intercept fitted.rSquared + Exponential_Model intercept -> + log_ys = ln_series known_ys "Y-values" + fitted = if intercept.is_nothing then Regression.fit_linear known_xs.to_array log_ys.to_array else + Regression.fit_linear known_xs.to_array log_ys.to_array intercept.ln + fitted_model_with_r_squared Fitted_Exponential_Model fitted.intercept.exp fitted.slope known_xs known_ys + Logarithmic_Model -> + log_xs = ln_series known_xs "X-values" + fitted = Regression.fit_linear log_xs.to_array known_ys.to_array + fitted_model_with_r_squared Fitted_Logarithmic_Model fitted.slope fitted.intercept known_xs known_ys + Power_Model -> + log_xs = ln_series known_xs "X-values" + log_ys = ln_series known_ys "Y-values" + fitted = Regression.fit_linear log_xs.to_array log_ys.to_array + fitted_model_with_r_squared Fitted_Power_Model fitted.intercept.exp fitted.slope known_xs known_ys + _ -> Error.throw (Illegal_Argument_Error "Unsupported model.") + +type Fitted_Model + ## Fitted line (y = slope x + intercept). + type Fitted_Linear_Model slope:Number intercept:Number r_squared:Number=0.0 + + ## Fitted exponential line (y = a exp(b x)). + type Fitted_Exponential_Model a:Number b:Number r_squared:Number=0.0 + + ## Fitted logarithmic line (y = a log x + b). + type Fitted_Logarithmic_Model a:Number b:Number r_squared:Number=0.0 + + ## Fitted power series (y = a x ^ b). + type Fitted_Power_Model a:Number b:Number r_squared:Number=0.0 + + ## Display the fitted line. 
+ to_text : Text + to_text = + equation = case self of + Fitted_Linear_Model slope intercept _ -> slope.to_text + " * X + " + intercept.to_text + Fitted_Exponential_Model a b _ -> a.to_text + " * (" + b.to_text + " * X).exp" + Fitted_Logarithmic_Model a b _ -> a.to_text + " * X.ln + " + b.to_text + Fitted_Power_Model a b _ -> a.to_text + " * X ^ " + b.to_text + "Fitted_Model(" + equation + ")" + + ## Use the model to predict the Y value for a given X value. + predict : Number -> Number + predict x = case self of + Fitted_Linear_Model slope intercept _ -> slope * x + intercept + Fitted_Exponential_Model a b _ -> a * (b * x).exp + Fitted_Logarithmic_Model a b _ -> a * x.ln + b + Fitted_Power_Model a b _ -> a * (x ^ b) + _ -> Error.throw (Illegal_Argument_Error "Unsupported model.") + +## PRIVATE + Computes the R Squared value for a model and returns a new instance. +fitted_model_with_r_squared : Any -> Number -> Number -> Vector -> Vector -> Fitted_Model +fitted_model_with_r_squared constructor a b known_xs known_ys = + model = constructor a b + r_squared = known_ys.compute (Statistics.R_Squared (known_xs.map model.predict)) + constructor a b r_squared + +## PRIVATE + + Computes the natural log series as long as all values are positive; `Nothing` entries stay `Nothing`. +ln_series : Vector -> Text -> Vector ! Illegal_Argument_Error +ln_series xs series_name="Values" = + ln_with_panic x = if x.is_nothing then Nothing else + if x <= 0 then Panic.throw (Illegal_Argument_Error (series_name + " must be positive.")) else x.ln + Panic.recover Illegal_Argument_Error <| xs.map ln_with_panic + +## PRIVATE + + An error thrown when the linear regression cannot be computed. + + Arguments: + - message: The error message. +type Fit_Error message + +## PRIVATE + + Converts the `Fit_Error` to a human-readable representation. 
+Fit_Error.to_display_text : Text +Fit_Error.to_display_text = "Could not fit the model: " + self.message.to_text + +## PRIVATE +Fit_Error.handle_java_exception = + Panic.catch_java FitError handler=(java_exception-> Error.throw (Fit_Error java_exception.getMessage)) diff --git a/distribution/lib/Standard/Base/0.0.0-dev/src/Data/Statistics.enso b/distribution/lib/Standard/Base/0.0.0-dev/src/Data/Statistics.enso index 8a3c2d8ee7..4795127c59 100644 --- a/distribution/lib/Standard/Base/0.0.0-dev/src/Data/Statistics.enso +++ b/distribution/lib/Standard/Base/0.0.0-dev/src/Data/Statistics.enso @@ -11,7 +11,6 @@ polyglot java import org.enso.base.statistics.CountMinMax polyglot java import org.enso.base.statistics.CorrelationStatistics polyglot java import org.enso.base.statistics.Rank -polyglot java import java.lang.IllegalArgumentException polyglot java import java.lang.ClassCastException polyglot java import java.lang.NullPointerException @@ -185,10 +184,7 @@ wrap_java_call ~function = report_unsupported _ = Error.throw (Illegal_Argument_Error ("Can only compute correlations on numerical data sets.")) handle_unsupported = Panic.catch Unsupported_Argument_Types handler=report_unsupported - report_illegal caught_panic = Error.throw (Illegal_Argument_Error caught_panic.payload.cause.getMessage) - handle_illegal = Panic.catch IllegalArgumentException handler=report_illegal - - handle_unsupported <| handle_illegal <| function + handle_unsupported <| Illegal_Argument_Error.handle_java_exception <| function ## PRIVATE diff --git a/distribution/lib/Standard/Base/0.0.0-dev/src/Error/Common.enso b/distribution/lib/Standard/Base/0.0.0-dev/src/Error/Common.enso index 9407339730..96c522a55c 100644 --- a/distribution/lib/Standard/Base/0.0.0-dev/src/Error/Common.enso +++ b/distribution/lib/Standard/Base/0.0.0-dev/src/Error/Common.enso @@ -2,6 +2,8 @@ from Standard.Base import all import Standard.Base.Data.Json import Standard.Base.Runtime +polyglot java import 
java.lang.IllegalArgumentException + ## Dataflow errors. type Error @@ -191,6 +193,11 @@ type Illegal_Argument_Error - cause: (optional) another error that is the cause of this one. type Illegal_Argument_Error message cause=Nothing + ## PRIVATE + Capture a Java IllegalArgumentException and rethrow it as an `Illegal_Argument_Error` whose cause is the Java exception. + handle_java_exception = + Panic.catch_java IllegalArgumentException handler=(cause-> Error.throw (Illegal_Argument_Error cause.getMessage cause)) + ## PRIVATE Wraps a dataflow error lifted to a panic, making possible to distinguish it from other panics. diff --git a/distribution/lib/Standard/Base/0.0.0-dev/src/Main.enso b/distribution/lib/Standard/Base/0.0.0-dev/src/Main.enso index a7951af416..6d37190239 100644 --- a/distribution/lib/Standard/Base/0.0.0-dev/src/Main.enso +++ b/distribution/lib/Standard/Base/0.0.0-dev/src/Main.enso @@ -34,6 +34,9 @@ import project.System.Environment import project.System.File import project.Data.Text.Regex.Mode as Regex_Mode import project.Warning +import project.Data.Statistics +import project.Data.Statistics.Rank_Method +import project.Data.Regression export project.Data.Interval export project.Data.Json @@ -42,6 +45,9 @@ export project.Data.Map export project.Data.Maybe export project.Data.Ordering export project.Data.Ordering.Sort_Direction +export project.Data.Regression +export project.Data.Statistics +export project.Data.Statistics.Rank_Method export project.Data.Vector export project.IO export project.Math diff --git a/distribution/lib/Standard/Table/0.0.0-dev/src/Data/Aggregate_Column.enso b/distribution/lib/Standard/Table/0.0.0-dev/src/Data/Aggregate_Column.enso index 09ffe4d069..c8a41244bb 100644 --- a/distribution/lib/Standard/Table/0.0.0-dev/src/Data/Aggregate_Column.enso +++ b/distribution/lib/Standard/Table/0.0.0-dev/src/Data/Aggregate_Column.enso @@ -114,7 +114,7 @@ type Aggregate_Column - column: column (specified by name, index or Column object) to compute standard deviation. - name: name of new column. 
- - population argument specifies if group is a sample or the population + - population: specifies if group is a sample or the population type Standard_Deviation (column:Column|Text|Integer) (new_name:Text|Nothing=Nothing) (population:Boolean=False) ## Creates a new column with the values concatenated together. `Nothing` diff --git a/distribution/lib/Standard/Table/0.0.0-dev/src/IO/Excel.enso b/distribution/lib/Standard/Table/0.0.0-dev/src/IO/Excel.enso index a4da809e3c..62480c6c72 100644 --- a/distribution/lib/Standard/Table/0.0.0-dev/src/IO/Excel.enso +++ b/distribution/lib/Standard/Table/0.0.0-dev/src/IO/Excel.enso @@ -18,7 +18,6 @@ polyglot java import org.enso.table.error.ExistingDataException polyglot java import org.enso.table.error.RangeExceededException polyglot java import org.enso.table.error.InvalidLocationException -polyglot java import java.lang.IllegalArgumentException polyglot java import java.lang.IllegalStateException polyglot java import java.io.IOException polyglot java import org.apache.poi.UnsupportedFileFormatException @@ -114,8 +113,7 @@ type Excel_Range ## Creates a Range from an address. from_address : Text -> Excel_Range from_address address = - illegal_argument caught_panic = Error.throw (Illegal_Argument_Error caught_panic.payload.cause.getMessage caught_panic.payload.cause) - Panic.catch IllegalArgumentException handler=illegal_argument <| + Illegal_Argument_Error.handle_java_exception <| Excel_Range (Java_Range.new address) ## Create a Range for a single cell. 
@@ -281,15 +279,11 @@ handle_writer ~writer = throw_existing_data caught_panic = Error.throw (Existing_Data caught_panic.payload.cause.getMessage) handle_existing_data = Panic.catch ExistingDataException handler=throw_existing_data - ## Illegal argument can occur if appending in an invalid mode - illegal_argument caught_panic = Error.throw (Illegal_Argument_Error caught_panic.payload.cause.getMessage caught_panic.payload.cause) - handle_illegal_argument = Panic.catch IllegalArgumentException handler=illegal_argument - ## Should be impossible - occurs if no fallback serializer is provided. throw_illegal_state caught_panic = Panic.throw (Illegal_State_Error caught_panic.payload.cause.getMessage) handle_illegal_state = Panic.catch IllegalStateException handler=throw_illegal_state handle_illegal_state <| Column_Name_Mismatch.handle_java_exception <| Column_Count_Mismatch.handle_java_exception <| handle_bad_location <| - handle_illegal_argument <| handle_range_exceeded <| handle_existing_data <| + Illegal_Argument_Error.handle_java_exception <| handle_range_exceeded <| handle_existing_data <| writer diff --git a/distribution/lib/Standard/Table/0.0.0-dev/src/Internal/Delimited_Reader.enso b/distribution/lib/Standard/Table/0.0.0-dev/src/Internal/Delimited_Reader.enso index 23e0aaa5c5..002a1e2e25 100644 --- a/distribution/lib/Standard/Table/0.0.0-dev/src/Internal/Delimited_Reader.enso +++ b/distribution/lib/Standard/Table/0.0.0-dev/src/Internal/Delimited_Reader.enso @@ -17,7 +17,6 @@ polyglot java import org.enso.table.parsing.problems.MismatchedQuote polyglot java import org.enso.table.parsing.problems.AdditionalInvalidRows polyglot java import org.enso.table.util.problems.DuplicateNames polyglot java import org.enso.table.util.problems.InvalidNames -polyglot java import java.lang.IllegalArgumentException polyglot java import java.io.IOException polyglot java import com.univocity.parsers.common.TextParsingException polyglot java import org.enso.base.Encoding_Utils @@ 
-91,7 +90,7 @@ read_stream format stream on_problems max_columns=default_max_columns related_fi integer. read_from_reader : Delimited -> Reader -> Problem_Behavior -> Integer -> Any read_from_reader format java_reader on_problems max_columns=4096 = - handle_illegal_arguments <| handle_parsing_failure <| handle_parsing_exception <| + Illegal_Argument_Error.handle_java_exception <| handle_parsing_failure <| handle_parsing_exception <| reader = prepare_delimited_reader java_reader format max_columns on_problems result_with_problems = reader.read parsing_problems = Vector.Vector (result_with_problems.problems) . map translate_reader_problem @@ -160,7 +159,7 @@ type Detected_File_Metadata detect_metadata : File -> File_Format.Delimited -> Detected_Headers detect_metadata file format = on_problems = Ignore - result = handle_io_exception file <| handle_illegal_arguments <| handle_parsing_failure <| handle_parsing_exception <| + result = handle_io_exception file <| Illegal_Argument_Error.handle_java_exception <| handle_parsing_failure <| handle_parsing_exception <| file.with_input_stream [File.Option.Read] stream-> stream.with_stream_decoder format.encoding on_problems java_reader-> ## We use the default `max_columns` setting. 
If we want to be able to @@ -179,12 +178,6 @@ detect_metadata file format = Detected_File_Metadata headers line_separator result.catch File.File_Not_Found (_->(Detected_File_Metadata Nothing Nothing)) -## PRIVATE -handle_illegal_arguments = - translate_illegal_argument caught_panic = - Error.throw (Illegal_Argument_Error caught_panic.payload.cause.getMessage) - Panic.catch IllegalArgumentException handler=translate_illegal_argument - ## PRIVATE handle_parsing_failure = translate_parsing_failure caught_panic = diff --git a/distribution/lib/Standard/Table/0.0.0-dev/src/Main.enso b/distribution/lib/Standard/Table/0.0.0-dev/src/Main.enso index c82a28272f..807b77b05a 100644 --- a/distribution/lib/Standard/Table/0.0.0-dev/src/Main.enso +++ b/distribution/lib/Standard/Table/0.0.0-dev/src/Main.enso @@ -6,12 +6,10 @@ import Standard.Table.IO.File_Format import Standard.Table.IO.Excel import Standard.Table.Data.Table import Standard.Table.Data.Column -import Standard.Table.Model from Standard.Table.IO.Excel export Excel_Section, Excel_Range export Standard.Table.Data.Column -export Standard.Table.Model export Standard.Table.IO.File_Read export Standard.Table.IO.File_Format diff --git a/distribution/lib/Standard/Table/0.0.0-dev/src/Model.enso b/distribution/lib/Standard/Table/0.0.0-dev/src/Model.enso deleted file mode 100644 index 96146bd327..0000000000 --- a/distribution/lib/Standard/Table/0.0.0-dev/src/Model.enso +++ /dev/null @@ -1,77 +0,0 @@ -from Standard.Base import all - -from Standard.Table import Column, Table - -## Compute the linear regression between the x and y coordinates, returning a - vector containing the slope in the first position and the bias in the second. - - Arguments: - - x_column: The name of the column in `self` containing the x values. - - y_column: The name of the column in `self` containing the y values. - - If the columns don't match in length, it throws a `Fit_Error`. - - > Example - Compute the linear regression between two columns in a table. 
- - from Standard.Table import all - - example_linear_regression = - column_x = Column.from_vector "x" [1, 2, 3, 4, 5] - column_y = Column.from_vector "y" [2, 4, 6, 8, 10] - table = Table.new [column_x, column_y] - table.linear_regression "x" "y" -Table.Table.linear_regression : Text -> Text -> Vector Number ! Fit_Error -Table.Table.linear_regression x_column y_column = - x_values = self.at x_column - y_values = self.at y_column - linear_regression x_values y_values - -## Compute the linear regression between the x and y coordinates, returning a - vector containing the slope in the first position and the bias in the second. - - Arguments: - - x_values: The column of x coordinate values for each coordinate pair. - - y_values: The column of y coordinate values for each coordinate pair. - - If the columns don't match in length, it throws a `Fit_Error`. - - > Example - Compute the linear regression between two columns. - - from Standard.Table import all - - example_linear_regression = - column_x = Column.from_vector "x" [1, 2, 3, 4, 5] - column_y = Column.from_vector "y" [2, 4, 6, 8, 10] - Model.linear_regression column_x column_y -linear_regression : Column -> Column -> Vector Number ! Fit_Error -linear_regression x_values y_values = - if x_values.length != y_values.length then Error.throw (Fit_Error "Columns have different lengths.") else - n = x_values.length - x_squared = x_values.map (^2) - x_y = x_values * y_values - - slope_numerator = (n * x_y.sum) - (x_values.sum * y_values.sum) - slope_denominator = (n * x_squared.sum) - (x_values.sum ^ 2) - slope = slope_numerator / slope_denominator - - bias_numerator = y_values.sum - (slope * x_values.sum) - bias = bias_numerator / n - - [slope, bias] - -## PRIVATE - - An error thrown when the linear regression cannot be computed. - - Arguments: - - message: The error message. -type Fit_Error message - -## PRIVATE - - Converts the `Fit_Error` to a human-readable representation. 
-Fit_Error.to_display_text : Text -Fit_Error.to_display_text = "Could not fit the model: " + self.message.to_text - diff --git a/std-bits/base/src/main/java/org/enso/base/statistics/CorrelationStatistics.java b/std-bits/base/src/main/java/org/enso/base/statistics/CorrelationStatistics.java index 04fa6a7311..7a46223d56 100644 --- a/std-bits/base/src/main/java/org/enso/base/statistics/CorrelationStatistics.java +++ b/std-bits/base/src/main/java/org/enso/base/statistics/CorrelationStatistics.java @@ -22,6 +22,51 @@ public class CorrelationStatistics { totalXY += x * y; } + /** + * Count of non-null pairs of values. + */ + public long getCount() { + return count; + } + + /** + * Sum of X values. + */ + public double getTotalX() { + return totalX; + } + + /** + * Sum of Y values. + */ + public double getTotalY() { + return totalY; + } + + /** + * Sum of X^2 values. + */ + public double getTotalXX() { + return totalXX; + } + + /** + * Sum of X * Y values. + */ + public double getTotalXY() { + return totalXY; + } + + /** + * Sum of Y^2 values. + */ + public double getTotalYY() { + return totalYY; + } + + /** + * Compute the covariance of X and Y (NaN when there are fewer than 2 pairs). + */ public double covariance() { if (count < 2) { return Double.NaN; @@ -30,6 +75,9 @@ public class CorrelationStatistics { return (totalXY - totalX * totalY / count) / count; } + /** + * Compute the Pearson correlation between X and Y (NaN when there are fewer than 2 pairs). + */ public double pearsonCorrelation() { if (count < 2) { return Double.NaN; @@ -40,6 +88,9 @@ public class CorrelationStatistics { return (count * totalXY - totalX * totalY) / (n_stdev_x * n_stdev_y); } + /** + * Compute the R-Squared between X and Y (which equals the Pearson correlation ^ 2). 
+ */ public double rSquared() { double correl = this.pearsonCorrelation(); return correl * correl; diff --git a/std-bits/base/src/main/java/org/enso/base/statistics/FitError.java b/std-bits/base/src/main/java/org/enso/base/statistics/FitError.java new file mode 100644 index 0000000000..cc3a99e9e0 --- /dev/null +++ b/std-bits/base/src/main/java/org/enso/base/statistics/FitError.java @@ -0,0 +1,10 @@ +package org.enso.base.statistics; + +/** + * A checked exception thrown when a model cannot be fitted. + */ +public class FitError extends Exception { + public FitError(String message) { + super(message); + } +} diff --git a/std-bits/base/src/main/java/org/enso/base/statistics/LinearModel.java b/std-bits/base/src/main/java/org/enso/base/statistics/LinearModel.java new file mode 100644 index 0000000000..8082a5ec85 --- /dev/null +++ b/std-bits/base/src/main/java/org/enso/base/statistics/LinearModel.java @@ -0,0 +1,3 @@ +package org.enso.base.statistics; + +public record LinearModel(double slope, double intercept, double rSquared) {} diff --git a/std-bits/base/src/main/java/org/enso/base/statistics/Regression.java b/std-bits/base/src/main/java/org/enso/base/statistics/Regression.java new file mode 100644 index 0000000000..e812d50817 --- /dev/null +++ b/std-bits/base/src/main/java/org/enso/base/statistics/Regression.java @@ -0,0 +1,58 @@ +package org.enso.base.statistics; + +public class Regression { + /** + * Performs a least squares fit of a line to the data. + * + * @param known_xs Set of known X values. + * @param known_ys Set of known Y values. + * @return A fitted linear model (y = Intercept + Slope x) and the r-squared value. + * @throws IllegalArgumentException if the number of elements in the arrays is different. + * @throws FitError if a singular X value is provided (the fit's denominator is zero). 
+ */ + public static LinearModel fit_linear(Double[] known_xs, Double[] known_ys) + throws IllegalArgumentException, FitError { + CorrelationStatistics stats = CorrelationStatistics.compute(known_xs, known_ys); + + double denominator = denominator(stats); + if (denominator == 0) { + throw new FitError("Singular X value."); + } + + double slope = slope(stats, denominator); + return new LinearModel(slope, intercept(stats, slope), stats.rSquared()); + } + + /** + * Performs a least squares fit of a line to the data with a given intercept. + * + * @param known_xs Set of known X values. + * @param known_ys Set of known Y values. + * @param intercept The intercept of the line. + * @return A fitted linear model (y = Intercept + Slope x) and the r-squared value. + * @throws IllegalArgumentException if the number of elements in the arrays is different. No + * singular X check is made here: a zero sum of X^2 gives a non-finite slope. + */ + public static LinearModel fit_linear(Double[] known_xs, Double[] known_ys, double intercept) + throws IllegalArgumentException { + CorrelationStatistics stats = CorrelationStatistics.compute(known_xs, known_ys); + return new LinearModel(slopeWithIntercept(stats, intercept), intercept, stats.rSquared()); + } + + private static double denominator(CorrelationStatistics stats) { + return stats.getTotalXX() - stats.getTotalX() * stats.getTotalX() / stats.getCount(); + } + + private static double slope(CorrelationStatistics stats, double denominator) { + return (stats.getTotalXY() - stats.getTotalX() * stats.getTotalY() / stats.getCount()) + / denominator; + } + + private static double slopeWithIntercept(CorrelationStatistics stats, double intercept) { + return (-intercept * stats.getTotalX() + stats.getTotalXY()) / stats.getTotalXX(); + } + + private static double intercept(CorrelationStatistics stats, double slope) { + return (stats.getTotalY() - stats.getTotalX() * slope) / stats.getCount(); + } +} diff --git a/test/Table_Tests/src/In_Memory_Tests.enso 
b/test/Table_Tests/src/In_Memory_Tests.enso index 7f1fd3b0b5..6c19530a8e 100644 --- a/test/Table_Tests/src/In_Memory_Tests.enso +++ b/test/Table_Tests/src/In_Memory_Tests.enso @@ -2,7 +2,6 @@ from Standard.Base import all import Standard.Test -import project.Model_Spec import project.Column_Spec import project.Csv_Spec import project.Delimited_Read_Spec @@ -21,7 +20,6 @@ in_memory_spec = Excel_Spec.spec Json_Spec.spec Table_Spec.spec - Model_Spec.spec Aggregate_Column_Spec.spec Aggregate_Spec.spec diff --git a/test/Table_Tests/src/Model_Spec.enso b/test/Table_Tests/src/Model_Spec.enso deleted file mode 100644 index 43c33033cf..0000000000 --- a/test/Table_Tests/src/Model_Spec.enso +++ /dev/null @@ -1,29 +0,0 @@ -from Standard.Base import all -from Standard.Table import all - -from Standard.Table.Model import Fit_Error - -import Standard.Test - -spec = - Test.group "Linear regression" <| - column_x = Column.from_vector "x" [2, 3, 5, 7, 9] - column_y = Column.from_vector "y" [4, 5, 7, 10, 15] - column_y_2 = Column.from_vector "y" [4, 5, 7, 10] - - table = Table.new [column_x, column_y] - - Test.specify "return an error if the column lengths do not match" <| - result = Model.linear_regression column_x column_y_2 - result . should_fail_with Fit_Error - - Test.specify "compute the linear least squares" <| - result = Model.linear_regression column_x column_y - result.length . should_equal 2 - result.at 0 . should_equal epsilon=0.001 1.518 - result.at 1 . should_equal epsilon=0.001 0.304 - - Test.specify "compute based on columns in a table" <| - result = table.linear_regression "x" "y" - result . 
should_equal (Model.linear_regression column_x column_y) - diff --git a/test/Tests/src/Data/Regression_Spec.enso b/test/Tests/src/Data/Regression_Spec.enso new file mode 100644 index 0000000000..4735a8b64d --- /dev/null +++ b/test/Tests/src/Data/Regression_Spec.enso @@ -0,0 +1,108 @@ +from Standard.Base import Nothing, Vector, Number, Decimal, True, Illegal_Argument_Error, False, Regression + +import Standard.Test + +spec = + ## Regression test data produced using an Excel spreadsheet. + https://github.com/enso-org/enso/files/9160145/Regression.tests.xlsx + + double_error = 0.000001 + + vector_compare values expected = + values.zip expected v->e-> + case v of + Decimal -> v.should_equal e epsilon=double_error + _ -> v.should_equal e + + Test.group "Regression" <| + Test.specify "return an error if the vector lengths do not match" <| + known_xs = [2, 3, 5, 7, 9] + known_ys = [4, 5, 7, 10] + Regression.fit_least_squares known_xs known_ys . should_fail_with Illegal_Argument_Error + + Test.specify "return an error if the X values are all the same" <| + known_xs = [2, 2, 2, 2] + known_ys = [4, 5, 7, 10] + Regression.fit_least_squares known_xs known_ys . should_fail_with Regression.Fit_Error + + Test.specify "compute the linear trend line" <| + known_xs = [2, 3, 5, 7, 9] + known_ys = [4, 5, 7, 10, 15] + fitted = Regression.fit_least_squares known_xs known_ys + fitted.slope . should_equal 1.518292683 epsilon=double_error + fitted.intercept . should_equal 0.304878049 epsilon=double_error + fitted.r_squared . 
should_equal 0.959530147 epsilon=double_error + + Test.specify "predict values on a linear trend line" <| + known_xs = [2, 3, 5, 7, 9] + known_ys = [4, 5, 7, 10, 15] + fitted = Regression.fit_least_squares known_xs known_ys + test_xs = [1, 4, 6, 8, 10] + expected_ys = [1.823171, 6.378049, 9.414634, 12.45122, 15.487805] + vector_compare (test_xs.map fitted.predict) expected_ys + + Test.specify "compute the linear trend line with an intercept" <| + known_xs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + known_ys = [8.02128, 11.02421, 13.99566, 17.02678, 20.00486, 22.95283, 26.0143, 29.03238, 31.96427, 35.03896] + fitted = Regression.fit_least_squares known_xs known_ys (Regression.Linear_Model 100) + fitted.slope . should_equal -10.57056558 epsilon=double_error + fitted.intercept . should_equal 100.0 epsilon=double_error + fitted.r_squared . should_equal 0.9999900045 epsilon=double_error + + Test.specify "compute the exponential trend line" <| + known_xs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + known_ys = [0.28652, 0.31735, 0.31963, 0.38482, 0.40056, 0.39013, 0.4976, 0.5665, 0.55457, 0.69135] + fitted = Regression.fit_least_squares known_xs known_ys Regression.Exponential_Model + fitted.a . should_equal 0.25356436 epsilon=double_error + fitted.b . should_equal 0.09358242 epsilon=double_error + fitted.r_squared . 
should_equal 0.9506293649 epsilon=double_error + + Test.specify "predict values on a exponential trend line" <| + known_xs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + known_ys = [0.28652, 0.31735, 0.31963, 0.38482, 0.40056, 0.39013, 0.4976, 0.5665, 0.55457, 0.69135] + fitted = Regression.fit_least_squares known_xs known_ys Regression.Exponential_Model + test_xs = [0, 11, 12, 15] + expected_ys = [0.253564, 0.709829, 0.779464, 1.032103] + vector_compare (test_xs.map fitted.predict) expected_ys + + Test.specify "compute the exponential trend line with an intercept" <| + known_xs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + known_ys = [0.28652, 0.31735, 0.31963, 0.38482, 0.40056, 0.39013, 0.4976, 0.5665, 0.55457, 0.69135] + fitted = Regression.fit_least_squares known_xs known_ys (Regression.Exponential_Model 0.2) + fitted.a . should_equal 0.2 epsilon=double_error + fitted.b . should_equal 0.127482464 epsilon=double_error + fitted.r_squared . should_equal 0.9566066546 epsilon=double_error + + Test.specify "compute the logarithmic trend line" <| + known_xs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + known_ys = [0.12128, 0.29057, 0.35933, 0.45949, 0.49113, 0.48285, 0.58132, 0.63144, 0.5916, 0.69158] + fitted = Regression.fit_least_squares known_xs known_ys Regression.Logarithmic_Model + fitted.a . should_equal 0.232702284 epsilon=double_error + fitted.b . should_equal 0.11857587 epsilon=double_error + fitted.r_squared . 
should_equal 0.9730840179 epsilon=double_error + + Test.specify "predict values on a logarithmic trend line" <| + known_xs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + known_ys = [0.12128, 0.29057, 0.35933, 0.45949, 0.49113, 0.48285, 0.58132, 0.63144, 0.5916, 0.69158] + fitted = Regression.fit_least_squares known_xs known_ys Regression.Logarithmic_Model + test_xs = [0.1, 11, 12, 15] + expected_ys = [-0.417241, 0.676572, 0.696819, 0.748745] + vector_compare (test_xs.map fitted.predict) expected_ys + + Test.specify "compute the power trend line" <| + known_xs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + known_ys = [0.26128, 0.28144, 0.26353, 0.30247, 0.28677, 0.23992, 0.30586, 0.32785, 0.26324, 0.3411] + fitted = Regression.fit_least_squares known_xs known_ys Regression.Power_Model + fitted.a . should_equal 0.258838019 epsilon=double_error + fitted.b . should_equal 0.065513849 epsilon=double_error + fitted.r_squared . should_equal 0.2099579581 epsilon=double_error + + Test.specify "predict values on a power trend line" <| + known_xs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + known_ys = [0.26128, 0.28144, 0.26353, 0.30247, 0.28677, 0.23992, 0.30586, 0.32785, 0.26324, 0.3411] + fitted = Regression.fit_least_squares known_xs known_ys Regression.Power_Model + test_xs = [0.1, 11, 12, 15] + expected_ys = [0.222594, 0.302868, 0.3046, 0.309085] + vector_compare (test_xs.map fitted.predict) expected_ys + +main = Test.Suite.run_main spec diff --git a/test/Tests/src/Main.enso b/test/Tests/src/Main.enso index d816482fb1..ecdd203c71 100644 --- a/test/Tests/src/Main.enso +++ b/test/Tests/src/Main.enso @@ -40,6 +40,7 @@ import project.Data.Text_Spec import project.Data.Time.Spec as Time_Spec import project.Data.Vector_Spec import project.Data.Statistics_Spec +import project.Data.Regression_Spec import project.Data.Text.Regex_Spec import project.Data.Text.Utils_Spec import project.Data.Text.Default_Regex_Engine_Spec @@ -118,5 +119,6 @@ main = Test.Suite.run_main <| Uri_Spec.spec Vector_Spec.spec 
Statistics_Spec.spec + Regression_Spec.spec Warnings_Spec.spec System_Spec.spec