OLSMultipleLinearRegression.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      https://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. /*
  18.  * This is not the original file distributed by the Apache Software Foundation
  19.  * It has been modified by the Hipparchus project
  20.  */
  21. package org.hipparchus.stat.regression;

  22. import org.hipparchus.exception.MathIllegalArgumentException;
  23. import org.hipparchus.linear.Array2DRowRealMatrix;
  24. import org.hipparchus.linear.LUDecomposition;
  25. import org.hipparchus.linear.QRDecomposition;
  26. import org.hipparchus.linear.RealMatrix;
  27. import org.hipparchus.linear.RealVector;
  28. import org.hipparchus.stat.StatUtils;
  29. import org.hipparchus.stat.descriptive.moment.SecondMoment;

  30. /**
  31.  * <p>Implements ordinary least squares (OLS) to estimate the parameters of a
  32.  * multiple linear regression model.</p>
  33.  *
  34.  * <p>The regression coefficients, <code>b</code>, satisfy the normal equations:</p>
  35.  * <pre><code> X<sup>T</sup> X b = X<sup>T</sup> y </code></pre>
  36.  *
  37.  * <p>
  38.  * To solve the normal equations, this implementation uses QR decomposition
  39.  * of the <code>X</code> matrix. (See {@link QRDecomposition} for details on the
  40.  * decomposition algorithm.) The <code>X</code> matrix, also known as the <i>design matrix,</i>
  41.  * has rows corresponding to sample observations and columns corresponding to independent
  42.  * variables.  When the model is estimated using an intercept term (i.e. when
  43.  * {@link #isNoIntercept() isNoIntercept} is false as it is by default), the <code>X</code>
  44.  * matrix includes an initial column identically equal to 1.  We solve the normal equations
  45.  * as follows:
  46.  * </p>
  47.  * <pre><code> X<sup>T</sup>X b = X<sup>T</sup> y
  48.  * (QR)<sup>T</sup> (QR) b = (QR)<sup>T</sup>y
  49.  * R<sup>T</sup> (Q<sup>T</sup>Q) R b = R<sup>T</sup> Q<sup>T</sup> y
  50.  * R<sup>T</sup> R b = R<sup>T</sup> Q<sup>T</sup> y
  51.  * (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> R b = (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> Q<sup>T</sup> y
  52.  * R b = Q<sup>T</sup> y </code></pre>
  53.  *
  54.  * <p>Given <code>Q</code> and <code>R</code>, the last equation is solved by back-substitution.</p>
  55.  *
  56.  */
  57. public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression {

  58.     /** Cached QR decomposition of X matrix */
  59.     private QRDecomposition qr;

  60.     /** Singularity threshold for QR decomposition */
  61.     private final double threshold;

  62.     /**
  63.      * Create an empty OLSMultipleLinearRegression instance.
  64.      */
  65.     public OLSMultipleLinearRegression() {
  66.         this(0d);
  67.     }

  68.     /**
  69.      * Create an empty OLSMultipleLinearRegression instance, using the given
  70.      * singularity threshold for the QR decomposition.
  71.      *
  72.      * @param threshold the singularity threshold
  73.      */
  74.     public OLSMultipleLinearRegression(final double threshold) {
  75.         this.threshold = threshold;
  76.     }

  77.     /**
  78.      * Loads model x and y sample data, overriding any previous sample.
  79.      *
  80.      * Computes and caches QR decomposition of the X matrix.
  81.      * @param y the [n,1] array representing the y sample
  82.      * @param x the [n,k] array representing the x sample
  83.      * @throws MathIllegalArgumentException if the x and y array data are not
  84.      *             compatible for the regression
  85.      */
  86.     public void newSampleData(double[] y, double[][] x) throws MathIllegalArgumentException {
  87.         validateSampleData(x, y);
  88.         newYSampleData(y);
  89.         newXSampleData(x);
  90.     }

  91.     /**
  92.      * {@inheritDoc}
  93.      * <p>This implementation computes and caches the QR decomposition of the X matrix.</p>
  94.      */
  95.     @Override
  96.     public void newSampleData(double[] data, int nobs, int nvars) {
  97.         super.newSampleData(data, nobs, nvars);
  98.         qr = new QRDecomposition(getX(), threshold);
  99.     }

  100.     /**
  101.      * <p>Compute the "hat" matrix.
  102.      * </p>
  103.      * <p>The hat matrix is defined in terms of the design matrix X
  104.      *  by X(X<sup>T</sup>X)<sup>-1</sup>X<sup>T</sup>
  105.      * </p>
  106.      * <p>The implementation here uses the QR decomposition to compute the
  107.      * hat matrix as Q I<sub>p</sub>Q<sup>T</sup> where I<sub>p</sub> is the
  108.      * p-dimensional identity matrix augmented by 0's.  This computational
  109.      * formula is from "The Hat Matrix in Regression and ANOVA",
  110.      * David C. Hoaglin and Roy E. Welsch,
  111.      * <i>The American Statistician</i>, Vol. 32, No. 1 (Feb., 1978), pp. 17-22.
  112.      * </p>
  113.      * <p>Data for the model must have been successfully loaded using one of
  114.      * the {@code newSampleData} methods before invoking this method; otherwise
  115.      * a {@code NullPointerException} will be thrown.</p>
  116.      *
  117.      * @return the hat matrix
  118.      * @throws NullPointerException unless method {@code newSampleData} has been
  119.      * called beforehand.
  120.      */
  121.     public RealMatrix calculateHat() {
  122.         // Create augmented identity matrix
  123.         RealMatrix Q = qr.getQ();
  124.         final int p = qr.getR().getColumnDimension();
  125.         final int n = Q.getColumnDimension();
  126.         // No try-catch or advertised MathIllegalArgumentException - NPE above if n < 3
  127.         Array2DRowRealMatrix augI = new Array2DRowRealMatrix(n, n);
  128.         double[][] augIData = augI.getDataRef();
  129.         for (int i = 0; i < n; i++) {
  130.             for (int j =0; j < n; j++) {
  131.                 if (i == j && i < p) {
  132.                     augIData[i][j] = 1d;
  133.                 } else {
  134.                     augIData[i][j] = 0d;
  135.                 }
  136.             }
  137.         }

  138.         // Compute and return Hat matrix
  139.         // No DME advertised - args valid if we get here
  140.         return Q.multiply(augI).multiplyTransposed(Q);
  141.     }

  142.     /**
  143.      * <p>Returns the sum of squared deviations of Y from its mean.</p>
  144.      *
  145.      * <p>If the model has no intercept term, <code>0</code> is used for the
  146.      * mean of Y - i.e., what is returned is the sum of the squared Y values.</p>
  147.      *
  148.      * <p>The value returned by this method is the SSTO value used in
  149.      * the {@link #calculateRSquared() R-squared} computation.</p>
  150.      *
  151.      * @return SSTO - the total sum of squares
  152.      * @throws NullPointerException if the sample has not been set
  153.      * @see #isNoIntercept()
  154.      */
  155.     public double calculateTotalSumOfSquares() {
  156.         if (isNoIntercept()) {
  157.             return StatUtils.sumSq(getY().toArray());
  158.         } else {
  159.             return new SecondMoment().evaluate(getY().toArray());
  160.         }
  161.     }

  162.     /**
  163.      * Returns the sum of squared residuals.
  164.      *
  165.      * @return residual sum of squares
  166.      * @throws org.hipparchus.exception.MathIllegalArgumentException if the design matrix is singular
  167.      * @throws NullPointerException if the data for the model have not been loaded
  168.      */
  169.     public double calculateResidualSumOfSquares() {
  170.         final RealVector residuals = calculateResiduals();
  171.         // No advertised DME, args are valid
  172.         return residuals.dotProduct(residuals);
  173.     }

  174.     /**
  175.      * Returns the R-Squared statistic, defined by the formula \(R^2 = 1 - \frac{\mathrm{SSR}}{\mathrm{SSTO}}\)
  176.      * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals}
  177.      * and SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}
  178.      *
  179.      * <p>If there is no variance in y, i.e., SSTO = 0, NaN is returned.</p>
  180.      *
  181.      * @return R-square statistic
  182.      * @throws NullPointerException if the sample has not been set
  183.      * @throws org.hipparchus.exception.MathIllegalArgumentException if the design matrix is singular
  184.      */
  185.     public double calculateRSquared() {
  186.         return 1 - calculateResidualSumOfSquares() / calculateTotalSumOfSquares();
  187.     }

  188.     /**
  189.      * <p>Returns the adjusted R-squared statistic, defined by the formula
  190.      * \(R_\mathrm{adj}^2 = 1 - \frac{\mathrm{SSR} (n - 1)}{\mathrm{SSTO} (n - p)}\)
  191.      * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals},
  192.      * SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}, n is the number
  193.      * of observations and p is the number of parameters estimated (including the intercept).</p>
  194.      *
  195.      * <p>If the regression is estimated without an intercept term, what is returned is </p><pre>
  196.      * <code> 1 - (1 - {@link #calculateRSquared()}) * (n / (n - p)) </code>
  197.      * </pre>
  198.      *
  199.      * <p>If there is no variance in y, i.e., SSTO = 0, NaN is returned.</p>
  200.      *
  201.      * @return adjusted R-Squared statistic
  202.      * @throws NullPointerException if the sample has not been set
  203.      * @throws org.hipparchus.exception.MathIllegalArgumentException if the design matrix is singular
  204.      * @see #isNoIntercept()
  205.      */
  206.     public double calculateAdjustedRSquared() {
  207.         final double n = getX().getRowDimension();
  208.         if (isNoIntercept()) {
  209.             return 1 - (1 - calculateRSquared()) * (n / (n - getX().getColumnDimension()));
  210.         } else {
  211.             return 1 - (calculateResidualSumOfSquares() * (n - 1)) /
  212.                 (calculateTotalSumOfSquares() * (n - getX().getColumnDimension()));
  213.         }
  214.     }

  215.     /**
  216.      * {@inheritDoc}
  217.      * <p>This implementation computes and caches the QR decomposition of the X matrix
  218.      * once it is successfully loaded.</p>
  219.      */
  220.     @Override
  221.     protected void newXSampleData(double[][] x) {
  222.         super.newXSampleData(x);
  223.         qr = new QRDecomposition(getX(), threshold);
  224.     }

  225.     /**
  226.      * Calculates the regression coefficients using OLS.
  227.      *
  228.      * <p>Data for the model must have been successfully loaded using one of
  229.      * the {@code newSampleData} methods before invoking this method; otherwise
  230.      * a {@code NullPointerException} will be thrown.</p>
  231.      *
  232.      * @return beta
  233.      * @throws org.hipparchus.exception.MathIllegalArgumentException if the design matrix is singular
  234.      * @throws NullPointerException if the data for the model have not been loaded
  235.      */
  236.     @Override
  237.     protected RealVector calculateBeta() {
  238.         return qr.getSolver().solve(getY());
  239.     }

  240.     /**
  241.      * <p>Calculates the variance-covariance matrix of the regression parameters.
  242.      * </p>
  243.      * <p>Var(b) = (X<sup>T</sup>X)<sup>-1</sup>
  244.      * </p>
  245.      * <p>Uses QR decomposition to reduce (X<sup>T</sup>X)<sup>-1</sup>
  246.      * to (R<sup>T</sup>R)<sup>-1</sup>, with only the top p rows of
  247.      * R included, where p = the length of the beta vector.</p>
  248.      *
  249.      * <p>Data for the model must have been successfully loaded using one of
  250.      * the {@code newSampleData} methods before invoking this method; otherwise
  251.      * a {@code NullPointerException} will be thrown.</p>
  252.      *
  253.      * @return The beta variance-covariance matrix
  254.      * @throws org.hipparchus.exception.MathIllegalArgumentException if the design matrix is singular
  255.      * @throws NullPointerException if the data for the model have not been loaded
  256.      */
  257.     @Override
  258.     protected RealMatrix calculateBetaVariance() {
  259.         int p = getX().getColumnDimension();
  260.         RealMatrix Raug = qr.getR().getSubMatrix(0, p - 1 , 0, p - 1);
  261.         RealMatrix Rinv = new LUDecomposition(Raug).getSolver().getInverse();
  262.         return Rinv.multiplyTransposed(Rinv);
  263.     }

  264. }