OLSMultipleLinearRegression.java
- /*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- /*
- * This is not the original file distributed by the Apache Software Foundation
- * It has been modified by the Hipparchus project
- */
- package org.hipparchus.stat.regression;
- import org.hipparchus.exception.MathIllegalArgumentException;
- import org.hipparchus.linear.Array2DRowRealMatrix;
- import org.hipparchus.linear.LUDecomposition;
- import org.hipparchus.linear.QRDecomposition;
- import org.hipparchus.linear.RealMatrix;
- import org.hipparchus.linear.RealVector;
- import org.hipparchus.stat.StatUtils;
- import org.hipparchus.stat.descriptive.moment.SecondMoment;
- /**
- * <p>Implements ordinary least squares (OLS) to estimate the parameters of a
- * multiple linear regression model.</p>
- *
- * <p>The regression coefficients, <code>b</code>, satisfy the normal equations:</p>
- * <pre><code> X<sup>T</sup> X b = X<sup>T</sup> y </code></pre>
- *
- * <p>
- * To solve the normal equations, this implementation uses QR decomposition
- * of the <code>X</code> matrix. (See {@link QRDecomposition} for details on the
- * decomposition algorithm.) The <code>X</code> matrix, also known as the <i>design matrix,</i>
- * has rows corresponding to sample observations and columns corresponding to independent
- * variables. When the model is estimated using an intercept term (i.e. when
- * {@link #isNoIntercept() isNoIntercept} is false as it is by default), the <code>X</code>
- * matrix includes an initial column identically equal to 1. We solve the normal equations
- * as follows:
- * </p>
- * <pre><code> X<sup>T</sup>X b = X<sup>T</sup> y
- * (QR)<sup>T</sup> (QR) b = (QR)<sup>T</sup>y
- * R<sup>T</sup> (Q<sup>T</sup>Q) R b = R<sup>T</sup> Q<sup>T</sup> y
- * R<sup>T</sup> R b = R<sup>T</sup> Q<sup>T</sup> y
- * (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> R b = (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> Q<sup>T</sup> y
- * R b = Q<sup>T</sup> y </code></pre>
- *
- * <p>Given <code>Q</code> and <code>R</code>, the last equation is solved by back-substitution.</p>
- *
- */
- public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression {
- /** Cached QR decomposition of X matrix */
- private QRDecomposition qr;
- /** Singularity threshold for QR decomposition */
- private final double threshold;
- /**
- * Create an empty OLSMultipleLinearRegression instance.
- */
- public OLSMultipleLinearRegression() {
- this(0d);
- }
- /**
- * Create an empty OLSMultipleLinearRegression instance, using the given
- * singularity threshold for the QR decomposition.
- *
- * @param threshold the singularity threshold
- */
- public OLSMultipleLinearRegression(final double threshold) {
- this.threshold = threshold;
- }
- /**
- * Loads model x and y sample data, overriding any previous sample.
- *
- * Computes and caches QR decomposition of the X matrix.
- * @param y the [n,1] array representing the y sample
- * @param x the [n,k] array representing the x sample
- * @throws MathIllegalArgumentException if the x and y array data are not
- * compatible for the regression
- */
- public void newSampleData(double[] y, double[][] x) throws MathIllegalArgumentException {
- validateSampleData(x, y);
- newYSampleData(y);
- newXSampleData(x);
- }
- /**
- * {@inheritDoc}
- * <p>This implementation computes and caches the QR decomposition of the X matrix.</p>
- */
- @Override
- public void newSampleData(double[] data, int nobs, int nvars) {
- super.newSampleData(data, nobs, nvars);
- qr = new QRDecomposition(getX(), threshold);
- }
- /**
- * <p>Compute the "hat" matrix.
- * </p>
- * <p>The hat matrix is defined in terms of the design matrix X
- * by X(X<sup>T</sup>X)<sup>-1</sup>X<sup>T</sup>
- * </p>
- * <p>The implementation here uses the QR decomposition to compute the
- * hat matrix as Q I<sub>p</sub>Q<sup>T</sup> where I<sub>p</sub> is the
- * p-dimensional identity matrix augmented by 0's. This computational
- * formula is from "The Hat Matrix in Regression and ANOVA",
- * David C. Hoaglin and Roy E. Welsch,
- * <i>The American Statistician</i>, Vol. 32, No. 1 (Feb., 1978), pp. 17-22.
- * </p>
- * <p>Data for the model must have been successfully loaded using one of
- * the {@code newSampleData} methods before invoking this method; otherwise
- * a {@code NullPointerException} will be thrown.</p>
- *
- * @return the hat matrix
- * @throws NullPointerException unless method {@code newSampleData} has been
- * called beforehand.
- */
- public RealMatrix calculateHat() {
- // Create augmented identity matrix
- RealMatrix Q = qr.getQ();
- final int p = qr.getR().getColumnDimension();
- final int n = Q.getColumnDimension();
- // No try-catch or advertised MathIllegalArgumentException - NPE above if n < 3
- Array2DRowRealMatrix augI = new Array2DRowRealMatrix(n, n);
- double[][] augIData = augI.getDataRef();
- for (int i = 0; i < n; i++) {
- for (int j =0; j < n; j++) {
- if (i == j && i < p) {
- augIData[i][j] = 1d;
- } else {
- augIData[i][j] = 0d;
- }
- }
- }
- // Compute and return Hat matrix
- // No DME advertised - args valid if we get here
- return Q.multiply(augI).multiplyTransposed(Q);
- }
- /**
- * <p>Returns the sum of squared deviations of Y from its mean.</p>
- *
- * <p>If the model has no intercept term, <code>0</code> is used for the
- * mean of Y - i.e., what is returned is the sum of the squared Y values.</p>
- *
- * <p>The value returned by this method is the SSTO value used in
- * the {@link #calculateRSquared() R-squared} computation.</p>
- *
- * @return SSTO - the total sum of squares
- * @throws NullPointerException if the sample has not been set
- * @see #isNoIntercept()
- */
- public double calculateTotalSumOfSquares() {
- if (isNoIntercept()) {
- return StatUtils.sumSq(getY().toArray());
- } else {
- return new SecondMoment().evaluate(getY().toArray());
- }
- }
- /**
- * Returns the sum of squared residuals.
- *
- * @return residual sum of squares
- * @throws org.hipparchus.exception.MathIllegalArgumentException if the design matrix is singular
- * @throws NullPointerException if the data for the model have not been loaded
- */
- public double calculateResidualSumOfSquares() {
- final RealVector residuals = calculateResiduals();
- // No advertised DME, args are valid
- return residuals.dotProduct(residuals);
- }
- /**
- * Returns the R-Squared statistic, defined by the formula \(R^2 = 1 - \frac{\mathrm{SSR}}{\mathrm{SSTO}}\)
- * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals}
- * and SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}
- *
- * <p>If there is no variance in y, i.e., SSTO = 0, NaN is returned.</p>
- *
- * @return R-square statistic
- * @throws NullPointerException if the sample has not been set
- * @throws org.hipparchus.exception.MathIllegalArgumentException if the design matrix is singular
- */
- public double calculateRSquared() {
- return 1 - calculateResidualSumOfSquares() / calculateTotalSumOfSquares();
- }
- /**
- * <p>Returns the adjusted R-squared statistic, defined by the formula
- * \(R_\mathrm{adj}^2 = 1 - \frac{\mathrm{SSR} (n - 1)}{\mathrm{SSTO} (n - p)}\)
- * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals},
- * SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}, n is the number
- * of observations and p is the number of parameters estimated (including the intercept).</p>
- *
- * <p>If the regression is estimated without an intercept term, what is returned is </p><pre>
- * <code> 1 - (1 - {@link #calculateRSquared()}) * (n / (n - p)) </code>
- * </pre>
- *
- * <p>If there is no variance in y, i.e., SSTO = 0, NaN is returned.</p>
- *
- * @return adjusted R-Squared statistic
- * @throws NullPointerException if the sample has not been set
- * @throws org.hipparchus.exception.MathIllegalArgumentException if the design matrix is singular
- * @see #isNoIntercept()
- */
- public double calculateAdjustedRSquared() {
- final double n = getX().getRowDimension();
- if (isNoIntercept()) {
- return 1 - (1 - calculateRSquared()) * (n / (n - getX().getColumnDimension()));
- } else {
- return 1 - (calculateResidualSumOfSquares() * (n - 1)) /
- (calculateTotalSumOfSquares() * (n - getX().getColumnDimension()));
- }
- }
- /**
- * {@inheritDoc}
- * <p>This implementation computes and caches the QR decomposition of the X matrix
- * once it is successfully loaded.</p>
- */
- @Override
- protected void newXSampleData(double[][] x) {
- super.newXSampleData(x);
- qr = new QRDecomposition(getX(), threshold);
- }
- /**
- * Calculates the regression coefficients using OLS.
- *
- * <p>Data for the model must have been successfully loaded using one of
- * the {@code newSampleData} methods before invoking this method; otherwise
- * a {@code NullPointerException} will be thrown.</p>
- *
- * @return beta
- * @throws org.hipparchus.exception.MathIllegalArgumentException if the design matrix is singular
- * @throws NullPointerException if the data for the model have not been loaded
- */
- @Override
- protected RealVector calculateBeta() {
- return qr.getSolver().solve(getY());
- }
- /**
- * <p>Calculates the variance-covariance matrix of the regression parameters.
- * </p>
- * <p>Var(b) = (X<sup>T</sup>X)<sup>-1</sup>
- * </p>
- * <p>Uses QR decomposition to reduce (X<sup>T</sup>X)<sup>-1</sup>
- * to (R<sup>T</sup>R)<sup>-1</sup>, with only the top p rows of
- * R included, where p = the length of the beta vector.</p>
- *
- * <p>Data for the model must have been successfully loaded using one of
- * the {@code newSampleData} methods before invoking this method; otherwise
- * a {@code NullPointerException} will be thrown.</p>
- *
- * @return The beta variance-covariance matrix
- * @throws org.hipparchus.exception.MathIllegalArgumentException if the design matrix is singular
- * @throws NullPointerException if the data for the model have not been loaded
- */
- @Override
- protected RealMatrix calculateBetaVariance() {
- int p = getX().getColumnDimension();
- RealMatrix Raug = qr.getR().getSubMatrix(0, p - 1 , 0, p - 1);
- RealMatrix Rinv = new LUDecomposition(Raug).getSolver().getInverse();
- return Rinv.multiplyTransposed(Rinv);
- }
- }