Add Linear Regression support for Vectors. (#3601)

Adds least squares regression APIs. Covers the basic 4 trend line types from Excel (doesn't cover Polynomial or Moving Average).
Removes the legacy `Model` module from `Standard.Table`.
This commit is contained in:
James Dunkerley 2022-07-22 09:41:17 +01:00 committed by GitHub
parent 5b4aac0138
commit be311457bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 362 additions and 133 deletions

View File

@ -166,6 +166,8 @@
- [Fixed the case of various type names and library paths][3590]
- [Added support for parsing `.pgpass` file and `PG*` environment variables for
the Postgres connection][3593]
- [Added `Regression` to the `Standard.Base` library and removed legacy `Model`
type from `Standard.Table`.][3601]
[debug-shortcuts]:
https://github.com/enso-org/enso/blob/develop/app/gui/docs/product/shortcuts.md#debug
@ -263,6 +265,7 @@
[3588]: https://github.com/enso-org/enso/pull/3588
[3590]: https://github.com/enso-org/enso/pull/3590
[3593]: https://github.com/enso-org/enso/pull/3593
[3601]: https://github.com/enso-org/enso/pull/3601
#### Enso Compiler

View File

@ -0,0 +1,108 @@
from Standard.Base import all
polyglot java import org.enso.base.statistics.Regression
polyglot java import org.enso.base.statistics.FitError
## The kinds of model that `fit_least_squares` can fit to the data.
type Model
## Fit a line (y = A x + B) to the data with an optional fixed intercept (the value of B).
type Linear_Model (intercept:Number|Nothing=Nothing)
## Fit an exponential line (y = A exp(B x)) to the data with an optional fixed intercept (the value of A).
type Exponential_Model (intercept:Number|Nothing=Nothing)
## Fit a logarithmic line (y = A log x + B) to the data.
type Logarithmic_Model
## Fit a power series (y = A x ^ B) to the data.
type Power_Model
## Fits the requested model to the data using the method of least squares.

   Arguments:
   - known_xs: The X values of the observations.
   - known_ys: The Y values of the observations.
   - model: The kind of model to fit. Defaults to a straight line
     (`Linear_Model`).

   The exponential, logarithmic and power models are fitted as straight lines
   through log-transformed series, so the series being transformed must be
   strictly positive.
fit_least_squares : Vector -> Vector -> Model -> Fitted_Model ! Illegal_Argument_Error | Fit_Error
fit_least_squares known_xs known_ys model=Linear_Model =
    # The handlers must wrap the whole `case` so that they are installed
    # before any of the Java calls below are executed.
    Illegal_Argument_Error.handle_java_exception <| Fit_Error.handle_java_exception <| case model of
        Linear_Model intercept ->
            xs = known_xs.to_array
            ys = known_ys.to_array
            line = if intercept.is_nothing then Regression.fit_linear xs ys else
                Regression.fit_linear xs ys intercept
            Fitted_Linear_Model line.slope line.intercept line.rSquared
        Exponential_Model intercept ->
            # ln y = ln A + B x, so fit a line through (x, ln y).
            log_ys = ln_series known_ys "Y-values"
            xs = known_xs.to_array
            ys = log_ys.to_array
            line = if intercept.is_nothing then Regression.fit_linear xs ys else
                Regression.fit_linear xs ys intercept.ln
            fitted_model_with_r_squared Fitted_Exponential_Model line.intercept.exp line.slope known_xs known_ys
        Logarithmic_Model ->
            # y = A ln x + B, so fit a line through (ln x, y).
            log_xs = ln_series known_xs "X-values"
            line = Regression.fit_linear log_xs.to_array known_ys.to_array
            fitted_model_with_r_squared Fitted_Logarithmic_Model line.slope line.intercept known_xs known_ys
        Power_Model ->
            # ln y = ln A + B ln x, so fit a line through (ln x, ln y).
            log_xs = ln_series known_xs "X-values"
            log_ys = ln_series known_ys "Y-values"
            line = Regression.fit_linear log_xs.to_array log_ys.to_array
            fitted_model_with_r_squared Fitted_Power_Model line.intercept.exp line.slope known_xs known_ys
        _ -> Error.throw (Illegal_Argument_Error "Unsupported model.")
## The result of fitting a model to data: the fitted coefficients together with
   the R-squared value of the fit.
type Fitted_Model
    ## Fitted line (y = slope x + intercept).
    type Fitted_Linear_Model slope:Number intercept:Number r_squared:Number=0.0

    ## Fitted exponential line (y = a exp(b x)).
    type Fitted_Exponential_Model a:Number b:Number r_squared:Number=0.0

    ## Fitted logarithmic line (y = a log x + b).
    type Fitted_Logarithmic_Model a:Number b:Number r_squared:Number=0.0

    ## Fitted power series (y = a x ^ b).
    type Fitted_Power_Model a:Number b:Number r_squared:Number=0.0

    ## Display the fitted line as its equation in terms of `X`.
    to_text : Text
    to_text =
        formula = case self of
            Fitted_Linear_Model gradient offset _ -> gradient.to_text + " * X + " + offset.to_text
            Fitted_Exponential_Model scale rate _ -> scale.to_text + " * (" + rate.to_text + " * X).exp"
            Fitted_Logarithmic_Model scale offset _ -> scale.to_text + " * X.ln + " + offset.to_text
            Fitted_Power_Model scale exponent _ -> scale.to_text + " * X ^ " + exponent.to_text
        "Fitted_Model(" + formula + ")"

    ## Use the model to predict the Y value at the given X value.

       Arguments:
       - x: The X value to evaluate the fitted equation at.
    predict : Number -> Number
    predict x = case self of
        Fitted_Linear_Model gradient offset _ ->
            gradient * x + offset
        Fitted_Exponential_Model scale rate _ ->
            exponent = rate * x
            scale * exponent.exp
        Fitted_Logarithmic_Model scale offset _ ->
            log_x = x.ln
            scale * log_x + offset
        Fitted_Power_Model scale exponent _ ->
            raised = x ^ exponent
            scale * raised
        _ -> Error.throw (Illegal_Argument_Error "Unsupported model.")
## PRIVATE

   Builds a fitted model from its two coefficients and scores it against the
   original data, returning a new instance that carries the R-squared value.

   Arguments:
   - constructor: The fitted model constructor to apply to the coefficients.
   - a: The first coefficient passed to the constructor.
   - b: The second coefficient passed to the constructor.
   - known_xs: The X values the model was fitted to.
   - known_ys: The Y values the model was fitted to.
fitted_model_with_r_squared : Any -> Number -> Number -> Vector -> Vector -> Fitted_Model
fitted_model_with_r_squared constructor a b known_xs known_ys =
    # A provisional instance (default R-squared) is needed to make predictions.
    unscored = constructor a b
    predicted_ys = known_xs.map (x-> unscored.predict x)
    score = known_ys.compute (Statistics.R_Squared predicted_ys)
    constructor a b score
## PRIVATE

   Computes the natural logarithm of every value in the series, as long as all
   of the values present are strictly positive.

   Arguments:
   - xs: The values to transform. `Nothing` entries are passed through
     unchanged.
   - series_name: The name used for the series in the error message.

   Returns an `Illegal_Argument_Error` if any value is zero or negative.
ln_series : Vector -> Text -> Vector ! Illegal_Argument_Error
ln_series xs series_name="Values" =
    # A panic is used to abandon the `map` at the first invalid value; it is
    # turned back into an error value by the `Panic.recover` below.
    ln_with_panic x = if x.is_nothing then Nothing else
        if x <= 0 then Panic.throw (Illegal_Argument_Error (series_name + " must be positive.")) else x.ln
    Panic.recover Illegal_Argument_Error <| xs.map ln_with_panic
## PRIVATE

   An error thrown when the linear regression cannot be computed.

   Arguments:
   - message: The error message.
type Fit_Error message

## PRIVATE

   Converts the `Fit_Error` to a human-readable representation.
Fit_Error.to_display_text : Text
Fit_Error.to_display_text =
    reason = self.message.to_text
    "Could not fit the model: " + reason

## PRIVATE

   Catches a Java `FitError` and returns it as a `Fit_Error` carrying the
   message of the Java exception.
Fit_Error.handle_java_exception =
    as_fit_error java_exception = Error.throw (Fit_Error java_exception.getMessage)
    Panic.catch_java FitError handler=as_fit_error

View File

@ -11,7 +11,6 @@ polyglot java import org.enso.base.statistics.CountMinMax
polyglot java import org.enso.base.statistics.CorrelationStatistics
polyglot java import org.enso.base.statistics.Rank
polyglot java import java.lang.IllegalArgumentException
polyglot java import java.lang.ClassCastException
polyglot java import java.lang.NullPointerException
@ -185,10 +184,7 @@ wrap_java_call ~function =
report_unsupported _ = Error.throw (Illegal_Argument_Error ("Can only compute correlations on numerical data sets."))
handle_unsupported = Panic.catch Unsupported_Argument_Types handler=report_unsupported
report_illegal caught_panic = Error.throw (Illegal_Argument_Error caught_panic.payload.cause.getMessage)
handle_illegal = Panic.catch IllegalArgumentException handler=report_illegal
handle_unsupported <| handle_illegal <| function
handle_unsupported <| Illegal_Argument_Error.handle_java_exception <| function
## PRIVATE

View File

@ -2,6 +2,8 @@ from Standard.Base import all
import Standard.Base.Data.Json
import Standard.Base.Runtime
polyglot java import java.lang.IllegalArgumentException
## Dataflow errors.
type Error
@ -191,6 +193,11 @@ type Illegal_Argument_Error
- cause: (optional) another error that is the cause of this one.
type Illegal_Argument_Error message cause=Nothing
## PRIVATE

   Catches a Java `IllegalArgumentException` and returns it as an
   `Illegal_Argument_Error` that carries the Java message and keeps the
   Java exception as its cause.
handle_java_exception =
    as_enso_error cause = Error.throw (Illegal_Argument_Error cause.getMessage cause)
    Panic.catch_java IllegalArgumentException handler=as_enso_error
## PRIVATE
Wraps a dataflow error lifted to a panic, making possible to distinguish it
from other panics.

View File

@ -34,6 +34,9 @@ import project.System.Environment
import project.System.File
import project.Data.Text.Regex.Mode as Regex_Mode
import project.Warning
import project.Data.Statistics
import project.Data.Statistics.Rank_Method
import project.Data.Regression
export project.Data.Interval
export project.Data.Json
@ -42,6 +45,9 @@ export project.Data.Map
export project.Data.Maybe
export project.Data.Ordering
export project.Data.Ordering.Sort_Direction
export project.Data.Regression
export project.Data.Statistics
export project.Data.Statistics.Rank_Method
export project.Data.Vector
export project.IO
export project.Math

View File

@ -114,7 +114,7 @@ type Aggregate_Column
- column: column (specified by name, index or Column object) to compute
standard deviation.
- name: name of new column.
- population argument specifies if group is a sample or the population
- population: specifies if group is a sample or the population
type Standard_Deviation (column:Column|Text|Integer) (new_name:Text|Nothing=Nothing) (population:Boolean=False)
## Creates a new column with the values concatenated together. `Nothing`

View File

@ -18,7 +18,6 @@ polyglot java import org.enso.table.error.ExistingDataException
polyglot java import org.enso.table.error.RangeExceededException
polyglot java import org.enso.table.error.InvalidLocationException
polyglot java import java.lang.IllegalArgumentException
polyglot java import java.lang.IllegalStateException
polyglot java import java.io.IOException
polyglot java import org.apache.poi.UnsupportedFileFormatException
@ -114,8 +113,7 @@ type Excel_Range
## Creates a Range from an address.
from_address : Text -> Excel_Range
from_address address =
illegal_argument caught_panic = Error.throw (Illegal_Argument_Error caught_panic.payload.cause.getMessage caught_panic.payload.cause)
Panic.catch IllegalArgumentException handler=illegal_argument <|
Illegal_Argument_Error.handle_java_exception <|
Excel_Range (Java_Range.new address)
## Create a Range for a single cell.
@ -281,15 +279,11 @@ handle_writer ~writer =
throw_existing_data caught_panic = Error.throw (Existing_Data caught_panic.payload.cause.getMessage)
handle_existing_data = Panic.catch ExistingDataException handler=throw_existing_data
## Illegal argument can occur if appending in an invalid mode
illegal_argument caught_panic = Error.throw (Illegal_Argument_Error caught_panic.payload.cause.getMessage caught_panic.payload.cause)
handle_illegal_argument = Panic.catch IllegalArgumentException handler=illegal_argument
## Should be impossible - occurs if no fallback serializer is provided.
throw_illegal_state caught_panic = Panic.throw (Illegal_State_Error caught_panic.payload.cause.getMessage)
handle_illegal_state = Panic.catch IllegalStateException handler=throw_illegal_state
handle_illegal_state <| Column_Name_Mismatch.handle_java_exception <|
Column_Count_Mismatch.handle_java_exception <| handle_bad_location <|
handle_illegal_argument <| handle_range_exceeded <| handle_existing_data <|
Illegal_Argument_Error.handle_java_exception <| handle_range_exceeded <| handle_existing_data <|
writer

View File

@ -17,7 +17,6 @@ polyglot java import org.enso.table.parsing.problems.MismatchedQuote
polyglot java import org.enso.table.parsing.problems.AdditionalInvalidRows
polyglot java import org.enso.table.util.problems.DuplicateNames
polyglot java import org.enso.table.util.problems.InvalidNames
polyglot java import java.lang.IllegalArgumentException
polyglot java import java.io.IOException
polyglot java import com.univocity.parsers.common.TextParsingException
polyglot java import org.enso.base.Encoding_Utils
@ -91,7 +90,7 @@ read_stream format stream on_problems max_columns=default_max_columns related_fi
integer.
read_from_reader : Delimited -> Reader -> Problem_Behavior -> Integer -> Any
read_from_reader format java_reader on_problems max_columns=4096 =
handle_illegal_arguments <| handle_parsing_failure <| handle_parsing_exception <|
Illegal_Argument_Error.handle_java_exception <| handle_parsing_failure <| handle_parsing_exception <|
reader = prepare_delimited_reader java_reader format max_columns on_problems
result_with_problems = reader.read
parsing_problems = Vector.Vector (result_with_problems.problems) . map translate_reader_problem
@ -160,7 +159,7 @@ type Detected_File_Metadata
detect_metadata : File -> File_Format.Delimited -> Detected_Headers
detect_metadata file format =
on_problems = Ignore
result = handle_io_exception file <| handle_illegal_arguments <| handle_parsing_failure <| handle_parsing_exception <|
result = handle_io_exception file <| Illegal_Argument_Error.handle_java_exception <| handle_parsing_failure <| handle_parsing_exception <|
file.with_input_stream [File.Option.Read] stream->
stream.with_stream_decoder format.encoding on_problems java_reader->
## We use the default `max_columns` setting. If we want to be able to
@ -179,12 +178,6 @@ detect_metadata file format =
Detected_File_Metadata headers line_separator
result.catch File.File_Not_Found (_->(Detected_File_Metadata Nothing Nothing))
## PRIVATE
handle_illegal_arguments =
translate_illegal_argument caught_panic =
Error.throw (Illegal_Argument_Error caught_panic.payload.cause.getMessage)
Panic.catch IllegalArgumentException handler=translate_illegal_argument
## PRIVATE
handle_parsing_failure =
translate_parsing_failure caught_panic =

View File

@ -6,12 +6,10 @@ import Standard.Table.IO.File_Format
import Standard.Table.IO.Excel
import Standard.Table.Data.Table
import Standard.Table.Data.Column
import Standard.Table.Model
from Standard.Table.IO.Excel export Excel_Section, Excel_Range
export Standard.Table.Data.Column
export Standard.Table.Model
export Standard.Table.IO.File_Read
export Standard.Table.IO.File_Format

View File

@ -1,77 +0,0 @@
from Standard.Base import all
from Standard.Table import Column, Table
## Compute the linear regression between two columns of this table, returning
   a vector containing the slope in the first position and the bias in the
   second.

   Arguments:
   - x_column: The name of the column in `self` containing the x values.
   - y_column: The name of the column in `self` containing the y values.

   If the columns don't match in length, it fails with a `Fit_Error`.

   > Example
     Compute the linear regression between two columns in a table.

         from Standard.Table import all

         example_linear_regression =
             column_x = Column.from_vector "x" [1, 2, 3, 4, 5]
             column_y = Column.from_vector "y" [2, 4, 6, 8, 10]
             table = Table.new [column_x, column_y]
             table.linear_regression "x" "y"
Table.Table.linear_regression : Text -> Text -> Vector Number ! Fit_Error
Table.Table.linear_regression x_column y_column =
    # Look both columns up by name and delegate to the column-based function.
    linear_regression (self.at x_column) (self.at y_column)
## Compute the linear regression between the x and y coordinates, returning a
   vector containing the slope in the first position and the bias in the
   second.

   Arguments:
   - x_values: The column of x coordinate values for each coordinate pair.
   - y_values: The column of y coordinate values for each coordinate pair.

   It fails with a `Fit_Error` if the columns don't match in length, if they
   are empty, or if the x values have no spread (in which case the slope is
   undefined).

   > Example
     Compute the linear regression between two columns.

         from Standard.Table import all

         example_linear_regression =
             column_x = Column.from_vector "x" [1, 2, 3, 4, 5]
             column_y = Column.from_vector "y" [2, 4, 6, 8, 10]
             Model.linear_regression column_x column_y
linear_regression : Column -> Column -> Vector Number ! Fit_Error
linear_regression x_values y_values =
    n = x_values.length
    if n != y_values.length then Error.throw (Fit_Error "Columns have different lengths.") else
        # Without any data both the slope and the bias would divide by zero.
        if n == 0 then Error.throw (Fit_Error "Columns are empty.") else
            sum_x = x_values.sum
            sum_y = y_values.sum
            sum_xx = (x_values.map (^2)).sum
            sum_xy = (x_values * y_values).sum
            slope_denominator = (n * sum_xx) - (sum_x ^ 2)
            # The denominator is zero when every x value is the same, so no
            # unique line passes through the data.
            if slope_denominator == 0 then Error.throw (Fit_Error "All X values are identical.") else
                slope = ((n * sum_xy) - (sum_x * sum_y)) / slope_denominator
                bias = (sum_y - (slope * sum_x)) / n
                [slope, bias]
## PRIVATE

   An error thrown when the linear regression cannot be computed.

   Arguments:
   - message: The error message.
type Fit_Error message

## PRIVATE

   Converts the `Fit_Error` to a human-readable representation.
Fit_Error.to_display_text : Text
Fit_Error.to_display_text =
    reason = self.message.to_text
    "Could not fit the model: " + reason

View File

@ -22,6 +22,51 @@ public class CorrelationStatistics {
totalXY += x * y;
}
/** Returns the count of non-null pairs of values. */
public long getCount() {
  return this.count;
}
/** Returns the sum of the X values. */
public double getTotalX() {
  return this.totalX;
}
/** Returns the sum of the Y values. */
public double getTotalY() {
  return this.totalY;
}
/** Returns the sum of the X^2 values. */
public double getTotalXX() {
  return this.totalXX;
}
/** Returns the sum of the X * Y values. */
public double getTotalXY() {
  return this.totalXY;
}
/** Returns the sum of the Y^2 values. */
public double getTotalYY() {
  return this.totalYY;
}
/*
* Compute the covariance of X and Y.
*/
public double covariance() {
if (count < 2) {
return Double.NaN;
@ -30,6 +75,9 @@ public class CorrelationStatistics {
return (totalXY - totalX * totalY / count) / count;
}
/*
* Compute the Pearson correlation between X and Y.
*/
public double pearsonCorrelation() {
if (count < 2) {
return Double.NaN;
@ -40,6 +88,9 @@ public class CorrelationStatistics {
return (count * totalXY - totalX * totalY) / (n_stdev_x * n_stdev_y);
}
/*
* Compute the R-Squared between X and Y (which equals the Pearson correlation ^ 2).
*/
public double rSquared() {
double correl = this.pearsonCorrelation();
return correl * correl;

View File

@ -0,0 +1,10 @@
package org.enso.base.statistics;
/** The exception thrown when a model cannot be fitted to the data. */
public class FitError extends Exception {
  /**
   * Creates the exception.
   *
   * @param message a description of why the fit failed
   */
  public FitError(final String message) {
    super(message);
  }
}

View File

@ -0,0 +1,3 @@
package org.enso.base.statistics;
/**
 * A fitted straight line (y = intercept + slope x) together with its r-squared value.
 *
 * @param slope the slope of the line
 * @param intercept the intercept of the line
 * @param rSquared the r-squared value reported for the fit
 */
public record LinearModel(
    double slope,
    double intercept,
    double rSquared) {}

View File

@ -0,0 +1,58 @@
package org.enso.base.statistics;
/** Least squares fitting of a straight line to paired X and Y values. */
public class Regression {

  /**
   * Performs a least squares fit of a line to the data.
   *
   * @param known_xs Set of known X values.
   * @param known_ys Set of known Y values.
   * @return A fitted linear model (y = Intercept + Slope x) and the r-squared value.
   * @throws IllegalArgumentException if the number of elements in the arrays is different.
   * @throws FitError if the X values have no spread (a singular X value), so that no unique line
   *     can be fitted.
   */
  public static LinearModel fit_linear(Double[] known_xs, Double[] known_ys)
      throws IllegalArgumentException, FitError {
    CorrelationStatistics stats = CorrelationStatistics.compute(known_xs, known_ys);

    double denominator = denominator(stats);
    if (denominator == 0) {
      throw new FitError("Singular X value.");
    }

    double slope = slope(stats, denominator);
    return new LinearModel(slope, intercept(stats, slope), stats.rSquared());
  }

  /**
   * Performs a least squares fit of a line to the data with a given intercept.
   *
   * <p>The reported r-squared value is the one computed from the X and Y data by {@code
   * CorrelationStatistics}; it does not depend on the supplied intercept.
   *
   * @param known_xs Set of known X values.
   * @param known_ys Set of known Y values.
   * @param intercept The intercept of the line.
   * @return A fitted linear model (y = Intercept + Slope x) and the r-squared value.
   * @throws IllegalArgumentException if the number of elements in the arrays is different, or if
   *     the sum of the squared X values is zero, in which case the slope is undefined.
   */
  public static LinearModel fit_linear(Double[] known_xs, Double[] known_ys, double intercept)
      throws IllegalArgumentException {
    CorrelationStatistics stats = CorrelationStatistics.compute(known_xs, known_ys);

    // The slope is divided by the sum of X^2, so a zero sum would otherwise silently produce a
    // non-finite slope. An unchecked exception is used to keep the method signature unchanged.
    if (stats.getTotalXX() == 0) {
      throw new IllegalArgumentException(
          "Cannot fit a slope through a fixed intercept: the X values are all zero or missing.");
    }

    return new LinearModel(slopeWithIntercept(stats, intercept), intercept, stats.rSquared());
  }

  /** Computes sum(x^2) - sum(x)^2 / n, the denominator of the unconstrained slope. */
  private static double denominator(CorrelationStatistics stats) {
    return stats.getTotalXX() - stats.getTotalX() * stats.getTotalX() / stats.getCount();
  }

  /** Computes the unconstrained slope: (sum(xy) - sum(x) * sum(y) / n) / denominator. */
  private static double slope(CorrelationStatistics stats, double denominator) {
    return (stats.getTotalXY() - stats.getTotalX() * stats.getTotalY() / stats.getCount())
        / denominator;
  }

  /** Computes the slope for a fixed intercept: (sum(xy) - intercept * sum(x)) / sum(x^2). */
  private static double slopeWithIntercept(CorrelationStatistics stats, double intercept) {
    return (-intercept * stats.getTotalX() + stats.getTotalXY()) / stats.getTotalXX();
  }

  /** Computes the intercept for the given slope: (sum(y) - slope * sum(x)) / n. */
  private static double intercept(CorrelationStatistics stats, double slope) {
    return (stats.getTotalY() - stats.getTotalX() * slope) / stats.getCount();
  }
}

View File

@ -2,7 +2,6 @@ from Standard.Base import all
import Standard.Test
import project.Model_Spec
import project.Column_Spec
import project.Csv_Spec
import project.Delimited_Read_Spec
@ -21,7 +20,6 @@ in_memory_spec =
Excel_Spec.spec
Json_Spec.spec
Table_Spec.spec
Model_Spec.spec
Aggregate_Column_Spec.spec
Aggregate_Spec.spec

View File

@ -1,29 +0,0 @@
from Standard.Base import all
from Standard.Table import all
from Standard.Table.Model import Fit_Error
import Standard.Test
# Specification for the linear regression functions of `Standard.Table.Model`.
spec =
Test.group "Linear regression" <|
# Shared fixtures: `column_y_2` is one element shorter than `column_x`.
column_x = Column.from_vector "x" [2, 3, 5, 7, 9]
column_y = Column.from_vector "y" [4, 5, 7, 10, 15]
column_y_2 = Column.from_vector "y" [4, 5, 7, 10]
table = Table.new [column_x, column_y]
Test.specify "return an error if the column lengths do not match" <|
result = Model.linear_regression column_x column_y_2
result . should_fail_with Fit_Error
Test.specify "compute the linear least squares" <|
result = Model.linear_regression column_x column_y
# The result holds the slope at index 0 and the bias at index 1.
result.length . should_equal 2
result.at 0 . should_equal epsilon=0.001 1.518
result.at 1 . should_equal epsilon=0.001 0.304
Test.specify "compute based on columns in a table" <|
# The table method must agree with the column-based function.
result = table.linear_regression "x" "y"
result . should_equal (Model.linear_regression column_x column_y)

View File

@ -0,0 +1,108 @@
from Standard.Base import Nothing, Vector, Number, Decimal, True, Illegal_Argument_Error, False, Regression
import Standard.Test
# Specification for `Regression.fit_least_squares` and for `predict` on the
# fitted models it returns.
spec =
## Regression test data produced using an Excel spreadsheet.
https://github.com/enso-org/enso/files/9160145/Regression.tests.xlsx
# Tolerance used when comparing computed decimals with the reference values.
double_error = 0.000001
# Compares two vectors pairwise: `Decimal` values within `double_error`,
# anything else for exact equality.
vector_compare values expected =
values.zip expected v->e->
case v of
Decimal -> v.should_equal e epsilon=double_error
_ -> v.should_equal e
Test.group "Regression" <|
Test.specify "return an error if the vector lengths do not match" <|
known_xs = [2, 3, 5, 7, 9]
known_ys = [4, 5, 7, 10]
Regression.fit_least_squares known_xs known_ys . should_fail_with Illegal_Argument_Error
Test.specify "return an error if the X values are all the same" <|
known_xs = [2, 2, 2, 2]
known_ys = [4, 5, 7, 10]
Regression.fit_least_squares known_xs known_ys . should_fail_with Regression.Fit_Error
# No model is passed here, so the default `Linear_Model` is fitted.
Test.specify "compute the linear trend line" <|
known_xs = [2, 3, 5, 7, 9]
known_ys = [4, 5, 7, 10, 15]
fitted = Regression.fit_least_squares known_xs known_ys
fitted.slope . should_equal 1.518292683 epsilon=double_error
fitted.intercept . should_equal 0.304878049 epsilon=double_error
fitted.r_squared . should_equal 0.959530147 epsilon=double_error
Test.specify "predict values on a linear trend line" <|
known_xs = [2, 3, 5, 7, 9]
known_ys = [4, 5, 7, 10, 15]
fitted = Regression.fit_least_squares known_xs known_ys
test_xs = [1, 4, 6, 8, 10]
expected_ys = [1.823171, 6.378049, 9.414634, 12.45122, 15.487805]
vector_compare (test_xs.map fitted.predict) expected_ys
# The intercept is fixed at 100 through the `Linear_Model` argument.
Test.specify "compute the linear trend line with an intercept" <|
known_xs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
known_ys = [8.02128, 11.02421, 13.99566, 17.02678, 20.00486, 22.95283, 26.0143, 29.03238, 31.96427, 35.03896]
fitted = Regression.fit_least_squares known_xs known_ys (Regression.Linear_Model 100)
fitted.slope . should_equal -10.57056558 epsilon=double_error
fitted.intercept . should_equal 100.0 epsilon=double_error
fitted.r_squared . should_equal 0.9999900045 epsilon=double_error
Test.specify "compute the exponential trend line" <|
known_xs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
known_ys = [0.28652, 0.31735, 0.31963, 0.38482, 0.40056, 0.39013, 0.4976, 0.5665, 0.55457, 0.69135]
fitted = Regression.fit_least_squares known_xs known_ys Regression.Exponential_Model
fitted.a . should_equal 0.25356436 epsilon=double_error
fitted.b . should_equal 0.09358242 epsilon=double_error
fitted.r_squared . should_equal 0.9506293649 epsilon=double_error
Test.specify "predict values on a exponential trend line" <|
known_xs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
known_ys = [0.28652, 0.31735, 0.31963, 0.38482, 0.40056, 0.39013, 0.4976, 0.5665, 0.55457, 0.69135]
fitted = Regression.fit_least_squares known_xs known_ys Regression.Exponential_Model
test_xs = [0, 11, 12, 15]
expected_ys = [0.253564, 0.709829, 0.779464, 1.032103]
vector_compare (test_xs.map fitted.predict) expected_ys
# The intercept is fixed at 0.2, which the fitted `a` is expected to equal.
Test.specify "compute the exponential trend line with an intercept" <|
known_xs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
known_ys = [0.28652, 0.31735, 0.31963, 0.38482, 0.40056, 0.39013, 0.4976, 0.5665, 0.55457, 0.69135]
fitted = Regression.fit_least_squares known_xs known_ys (Regression.Exponential_Model 0.2)
fitted.a . should_equal 0.2 epsilon=double_error
fitted.b . should_equal 0.127482464 epsilon=double_error
fitted.r_squared . should_equal 0.9566066546 epsilon=double_error
Test.specify "compute the logarithmic trend line" <|
known_xs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
known_ys = [0.12128, 0.29057, 0.35933, 0.45949, 0.49113, 0.48285, 0.58132, 0.63144, 0.5916, 0.69158]
fitted = Regression.fit_least_squares known_xs known_ys Regression.Logarithmic_Model
fitted.a . should_equal 0.232702284 epsilon=double_error
fitted.b . should_equal 0.11857587 epsilon=double_error
fitted.r_squared . should_equal 0.9730840179 epsilon=double_error
Test.specify "predict values on a logarithmic trend line" <|
known_xs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
known_ys = [0.12128, 0.29057, 0.35933, 0.45949, 0.49113, 0.48285, 0.58132, 0.63144, 0.5916, 0.69158]
fitted = Regression.fit_least_squares known_xs known_ys Regression.Logarithmic_Model
test_xs = [0.1, 11, 12, 15]
expected_ys = [-0.417241, 0.676572, 0.696819, 0.748745]
vector_compare (test_xs.map fitted.predict) expected_ys
Test.specify "compute the power trend line" <|
known_xs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
known_ys = [0.26128, 0.28144, 0.26353, 0.30247, 0.28677, 0.23992, 0.30586, 0.32785, 0.26324, 0.3411]
fitted = Regression.fit_least_squares known_xs known_ys Regression.Power_Model
fitted.a . should_equal 0.258838019 epsilon=double_error
fitted.b . should_equal 0.065513849 epsilon=double_error
fitted.r_squared . should_equal 0.2099579581 epsilon=double_error
Test.specify "predict values on a power trend line" <|
known_xs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
known_ys = [0.26128, 0.28144, 0.26353, 0.30247, 0.28677, 0.23992, 0.30586, 0.32785, 0.26324, 0.3411]
fitted = Regression.fit_least_squares known_xs known_ys Regression.Power_Model
test_xs = [0.1, 11, 12, 15]
expected_ys = [0.222594, 0.302868, 0.3046, 0.309085]
vector_compare (test_xs.map fitted.predict) expected_ys
# Entry point that runs this specification on its own as a test suite.
main =
    Test.Suite.run_main spec

View File

@ -40,6 +40,7 @@ import project.Data.Text_Spec
import project.Data.Time.Spec as Time_Spec
import project.Data.Vector_Spec
import project.Data.Statistics_Spec
import project.Data.Regression_Spec
import project.Data.Text.Regex_Spec
import project.Data.Text.Utils_Spec
import project.Data.Text.Default_Regex_Engine_Spec
@ -118,5 +119,6 @@ main = Test.Suite.run_main <|
Uri_Spec.spec
Vector_Spec.spec
Statistics_Spec.spec
Regression_Spec.spec
Warnings_Spec.spec
System_Spec.spec