View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  package org.hipparchus.stat.regression;
23  
24  import org.hipparchus.exception.MathIllegalArgumentException;
25  import org.hipparchus.linear.Array2DRowRealMatrix;
26  import org.hipparchus.linear.LUDecomposition;
27  import org.hipparchus.linear.QRDecomposition;
28  import org.hipparchus.linear.RealMatrix;
29  import org.hipparchus.linear.RealVector;
30  import org.hipparchus.stat.StatUtils;
31  import org.hipparchus.stat.descriptive.moment.SecondMoment;
32  
33  /**
34   * <p>Implements ordinary least squares (OLS) to estimate the parameters of a
35   * multiple linear regression model.</p>
36   *
37   * <p>The regression coefficients, <code>b</code>, satisfy the normal equations:</p>
38   * <pre><code> X<sup>T</sup> X b = X<sup>T</sup> y </code></pre>
39   *
40   * <p>
41   * To solve the normal equations, this implementation uses QR decomposition
42   * of the <code>X</code> matrix. (See {@link QRDecomposition} for details on the
43   * decomposition algorithm.) The <code>X</code> matrix, also known as the <i>design matrix,</i>
44   * has rows corresponding to sample observations and columns corresponding to independent
45   * variables.  When the model is estimated using an intercept term (i.e. when
46   * {@link #isNoIntercept() isNoIntercept} is false as it is by default), the <code>X</code>
47   * matrix includes an initial column identically equal to 1.  We solve the normal equations
48   * as follows:
49   * </p>
50   * <pre><code> X<sup>T</sup>X b = X<sup>T</sup> y
51   * (QR)<sup>T</sup> (QR) b = (QR)<sup>T</sup>y
52   * R<sup>T</sup> (Q<sup>T</sup>Q) R b = R<sup>T</sup> Q<sup>T</sup> y
53   * R<sup>T</sup> R b = R<sup>T</sup> Q<sup>T</sup> y
54   * (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> R b = (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> Q<sup>T</sup> y
55   * R b = Q<sup>T</sup> y </code></pre>
56   *
57   * <p>Given <code>Q</code> and <code>R</code>, the last equation is solved by back-substitution.</p>
58   *
59   */
60  public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression {
61  
62      /** Cached QR decomposition of X matrix */
63      private QRDecomposition qr;
64  
65      /** Singularity threshold for QR decomposition */
66      private final double threshold;
67  
68      /**
69       * Create an empty OLSMultipleLinearRegression instance.
70       */
71      public OLSMultipleLinearRegression() {
72          this(0d);
73      }
74  
75      /**
76       * Create an empty OLSMultipleLinearRegression instance, using the given
77       * singularity threshold for the QR decomposition.
78       *
79       * @param threshold the singularity threshold
80       */
81      public OLSMultipleLinearRegression(final double threshold) {
82          this.threshold = threshold;
83      }
84  
85      /**
86       * Loads model x and y sample data, overriding any previous sample.
87       *
88       * Computes and caches QR decomposition of the X matrix.
89       * @param y the [n,1] array representing the y sample
90       * @param x the [n,k] array representing the x sample
91       * @throws MathIllegalArgumentException if the x and y array data are not
92       *             compatible for the regression
93       */
94      public void newSampleData(double[] y, double[][] x) throws MathIllegalArgumentException {
95          validateSampleData(x, y);
96          newYSampleData(y);
97          newXSampleData(x);
98      }
99  
100     /**
101      * {@inheritDoc}
102      * <p>This implementation computes and caches the QR decomposition of the X matrix.</p>
103      */
104     @Override
105     public void newSampleData(double[] data, int nobs, int nvars) {
106         super.newSampleData(data, nobs, nvars);
107         qr = new QRDecomposition(getX(), threshold);
108     }
109 
110     /**
111      * <p>Compute the "hat" matrix.
112      * </p>
113      * <p>The hat matrix is defined in terms of the design matrix X
114      *  by X(X<sup>T</sup>X)<sup>-1</sup>X<sup>T</sup>
115      * </p>
116      * <p>The implementation here uses the QR decomposition to compute the
117      * hat matrix as Q I<sub>p</sub>Q<sup>T</sup> where I<sub>p</sub> is the
118      * p-dimensional identity matrix augmented by 0's.  This computational
119      * formula is from "The Hat Matrix in Regression and ANOVA",
120      * David C. Hoaglin and Roy E. Welsch,
121      * <i>The American Statistician</i>, Vol. 32, No. 1 (Feb., 1978), pp. 17-22.
122      * </p>
123      * <p>Data for the model must have been successfully loaded using one of
124      * the {@code newSampleData} methods before invoking this method; otherwise
125      * a {@code NullPointerException} will be thrown.</p>
126      *
127      * @return the hat matrix
128      * @throws NullPointerException unless method {@code newSampleData} has been
129      * called beforehand.
130      */
131     public RealMatrix calculateHat() {
132         // Create augmented identity matrix
133         RealMatrix Q = qr.getQ();
134         final int p = qr.getR().getColumnDimension();
135         final int n = Q.getColumnDimension();
136         // No try-catch or advertised MathIllegalArgumentException - NPE above if n < 3
137         Array2DRowRealMatrix augI = new Array2DRowRealMatrix(n, n);
138         double[][] augIData = augI.getDataRef();
139         for (int i = 0; i < n; i++) {
140             for (int j =0; j < n; j++) {
141                 if (i == j && i < p) {
142                     augIData[i][j] = 1d;
143                 } else {
144                     augIData[i][j] = 0d;
145                 }
146             }
147         }
148 
149         // Compute and return Hat matrix
150         // No DME advertised - args valid if we get here
151         return Q.multiply(augI).multiplyTransposed(Q);
152     }
153 
154     /**
155      * <p>Returns the sum of squared deviations of Y from its mean.</p>
156      *
157      * <p>If the model has no intercept term, <code>0</code> is used for the
158      * mean of Y - i.e., what is returned is the sum of the squared Y values.</p>
159      *
160      * <p>The value returned by this method is the SSTO value used in
161      * the {@link #calculateRSquared() R-squared} computation.</p>
162      *
163      * @return SSTO - the total sum of squares
164      * @throws NullPointerException if the sample has not been set
165      * @see #isNoIntercept()
166      */
167     public double calculateTotalSumOfSquares() {
168         if (isNoIntercept()) {
169             return StatUtils.sumSq(getY().toArray());
170         } else {
171             return new SecondMoment().evaluate(getY().toArray());
172         }
173     }
174 
175     /**
176      * Returns the sum of squared residuals.
177      *
178      * @return residual sum of squares
179      * @throws org.hipparchus.exception.MathIllegalArgumentException if the design matrix is singular
180      * @throws NullPointerException if the data for the model have not been loaded
181      */
182     public double calculateResidualSumOfSquares() {
183         final RealVector residuals = calculateResiduals();
184         // No advertised DME, args are valid
185         return residuals.dotProduct(residuals);
186     }
187 
188     /**
189      * Returns the R-Squared statistic, defined by the formula \(R^2 = 1 - \frac{\mathrm{SSR}}{\mathrm{SSTO}}\)
190      * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals}
191      * and SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}
192      *
193      * <p>If there is no variance in y, i.e., SSTO = 0, NaN is returned.</p>
194      *
195      * @return R-square statistic
196      * @throws NullPointerException if the sample has not been set
197      * @throws org.hipparchus.exception.MathIllegalArgumentException if the design matrix is singular
198      */
199     public double calculateRSquared() {
200         return 1 - calculateResidualSumOfSquares() / calculateTotalSumOfSquares();
201     }
202 
203     /**
204      * <p>Returns the adjusted R-squared statistic, defined by the formula
205      * \(R_\mathrm{adj}^2 = 1 - \frac{\mathrm{SSR} (n - 1)}{\mathrm{SSTO} (n - p)}\)
206      * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals},
207      * SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}, n is the number
208      * of observations and p is the number of parameters estimated (including the intercept).</p>
209      *
210      * <p>If the regression is estimated without an intercept term, what is returned is </p><pre>
211      * <code> 1 - (1 - {@link #calculateRSquared()}) * (n / (n - p)) </code>
212      * </pre>
213      *
214      * <p>If there is no variance in y, i.e., SSTO = 0, NaN is returned.</p>
215      *
216      * @return adjusted R-Squared statistic
217      * @throws NullPointerException if the sample has not been set
218      * @throws org.hipparchus.exception.MathIllegalArgumentException if the design matrix is singular
219      * @see #isNoIntercept()
220      */
221     public double calculateAdjustedRSquared() {
222         final double n = getX().getRowDimension();
223         if (isNoIntercept()) {
224             return 1 - (1 - calculateRSquared()) * (n / (n - getX().getColumnDimension()));
225         } else {
226             return 1 - (calculateResidualSumOfSquares() * (n - 1)) /
227                 (calculateTotalSumOfSquares() * (n - getX().getColumnDimension()));
228         }
229     }
230 
231     /**
232      * {@inheritDoc}
233      * <p>This implementation computes and caches the QR decomposition of the X matrix
234      * once it is successfully loaded.</p>
235      */
236     @Override
237     protected void newXSampleData(double[][] x) {
238         super.newXSampleData(x);
239         qr = new QRDecomposition(getX(), threshold);
240     }
241 
242     /**
243      * Calculates the regression coefficients using OLS.
244      *
245      * <p>Data for the model must have been successfully loaded using one of
246      * the {@code newSampleData} methods before invoking this method; otherwise
247      * a {@code NullPointerException} will be thrown.</p>
248      *
249      * @return beta
250      * @throws org.hipparchus.exception.MathIllegalArgumentException if the design matrix is singular
251      * @throws NullPointerException if the data for the model have not been loaded
252      */
253     @Override
254     protected RealVector calculateBeta() {
255         return qr.getSolver().solve(getY());
256     }
257 
258     /**
259      * <p>Calculates the variance-covariance matrix of the regression parameters.
260      * </p>
261      * <p>Var(b) = (X<sup>T</sup>X)<sup>-1</sup>
262      * </p>
263      * <p>Uses QR decomposition to reduce (X<sup>T</sup>X)<sup>-1</sup>
264      * to (R<sup>T</sup>R)<sup>-1</sup>, with only the top p rows of
265      * R included, where p = the length of the beta vector.</p>
266      *
267      * <p>Data for the model must have been successfully loaded using one of
268      * the {@code newSampleData} methods before invoking this method; otherwise
269      * a {@code NullPointerException} will be thrown.</p>
270      *
271      * @return The beta variance-covariance matrix
272      * @throws org.hipparchus.exception.MathIllegalArgumentException if the design matrix is singular
273      * @throws NullPointerException if the data for the model have not been loaded
274      */
275     @Override
276     protected RealMatrix calculateBetaVariance() {
277         int p = getX().getColumnDimension();
278         RealMatrix Raug = qr.getR().getSubMatrix(0, p - 1 , 0, p - 1);
279         RealMatrix Rinv = new LUDecomposition(Raug).getSolver().getInverse();
280         return Rinv.multiplyTransposed(Rinv);
281     }
282 
283 }