1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * https://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 /*
19 * This is not the original file distributed by the Apache Software Foundation
20 * It has been modified by the Hipparchus project
21 */
22 package org.hipparchus.stat.regression;
23
24 import org.hipparchus.exception.MathIllegalArgumentException;
25 import org.hipparchus.linear.Array2DRowRealMatrix;
26 import org.hipparchus.linear.LUDecomposition;
27 import org.hipparchus.linear.QRDecomposition;
28 import org.hipparchus.linear.RealMatrix;
29 import org.hipparchus.linear.RealVector;
30 import org.hipparchus.stat.StatUtils;
31 import org.hipparchus.stat.descriptive.moment.SecondMoment;
32
33 /**
34 * <p>Implements ordinary least squares (OLS) to estimate the parameters of a
35 * multiple linear regression model.</p>
36 *
37 * <p>The regression coefficients, <code>b</code>, satisfy the normal equations:</p>
38 * <pre><code> X<sup>T</sup> X b = X<sup>T</sup> y </code></pre>
39 *
40 * <p>
41 * To solve the normal equations, this implementation uses QR decomposition
42 * of the <code>X</code> matrix. (See {@link QRDecomposition} for details on the
43 * decomposition algorithm.) The <code>X</code> matrix, also known as the <i>design matrix,</i>
44 * has rows corresponding to sample observations and columns corresponding to independent
45 * variables. When the model is estimated using an intercept term (i.e. when
46 * {@link #isNoIntercept() isNoIntercept} is false as it is by default), the <code>X</code>
47 * matrix includes an initial column identically equal to 1. We solve the normal equations
48 * as follows:
49 * </p>
50 * <pre><code> X<sup>T</sup>X b = X<sup>T</sup> y
51 * (QR)<sup>T</sup> (QR) b = (QR)<sup>T</sup>y
52 * R<sup>T</sup> (Q<sup>T</sup>Q) R b = R<sup>T</sup> Q<sup>T</sup> y
53 * R<sup>T</sup> R b = R<sup>T</sup> Q<sup>T</sup> y
54 * (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> R b = (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> Q<sup>T</sup> y
55 * R b = Q<sup>T</sup> y </code></pre>
56 *
57 * <p>Given <code>Q</code> and <code>R</code>, the last equation is solved by back-substitution.</p>
58 *
59 */
60 public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression {
61
62 /** Cached QR decomposition of X matrix */
63 private QRDecomposition qr;
64
65 /** Singularity threshold for QR decomposition */
66 private final double threshold;
67
68 /**
69 * Create an empty OLSMultipleLinearRegression instance.
70 */
71 public OLSMultipleLinearRegression() {
72 this(0d);
73 }
74
75 /**
76 * Create an empty OLSMultipleLinearRegression instance, using the given
77 * singularity threshold for the QR decomposition.
78 *
79 * @param threshold the singularity threshold
80 */
81 public OLSMultipleLinearRegression(final double threshold) {
82 this.threshold = threshold;
83 }
84
85 /**
86 * Loads model x and y sample data, overriding any previous sample.
87 *
88 * Computes and caches QR decomposition of the X matrix.
89 * @param y the [n,1] array representing the y sample
90 * @param x the [n,k] array representing the x sample
91 * @throws MathIllegalArgumentException if the x and y array data are not
92 * compatible for the regression
93 */
94 public void newSampleData(double[] y, double[][] x) throws MathIllegalArgumentException {
95 validateSampleData(x, y);
96 newYSampleData(y);
97 newXSampleData(x);
98 }
99
100 /**
101 * {@inheritDoc}
102 * <p>This implementation computes and caches the QR decomposition of the X matrix.</p>
103 */
104 @Override
105 public void newSampleData(double[] data, int nobs, int nvars) {
106 super.newSampleData(data, nobs, nvars);
107 qr = new QRDecomposition(getX(), threshold);
108 }
109
110 /**
111 * <p>Compute the "hat" matrix.
112 * </p>
113 * <p>The hat matrix is defined in terms of the design matrix X
114 * by X(X<sup>T</sup>X)<sup>-1</sup>X<sup>T</sup>
115 * </p>
116 * <p>The implementation here uses the QR decomposition to compute the
117 * hat matrix as Q I<sub>p</sub>Q<sup>T</sup> where I<sub>p</sub> is the
118 * p-dimensional identity matrix augmented by 0's. This computational
119 * formula is from "The Hat Matrix in Regression and ANOVA",
120 * David C. Hoaglin and Roy E. Welsch,
121 * <i>The American Statistician</i>, Vol. 32, No. 1 (Feb., 1978), pp. 17-22.
122 * </p>
123 * <p>Data for the model must have been successfully loaded using one of
124 * the {@code newSampleData} methods before invoking this method; otherwise
125 * a {@code NullPointerException} will be thrown.</p>
126 *
127 * @return the hat matrix
128 * @throws NullPointerException unless method {@code newSampleData} has been
129 * called beforehand.
130 */
131 public RealMatrix calculateHat() {
132 // Create augmented identity matrix
133 RealMatrix Q = qr.getQ();
134 final int p = qr.getR().getColumnDimension();
135 final int n = Q.getColumnDimension();
136 // No try-catch or advertised MathIllegalArgumentException - NPE above if n < 3
137 Array2DRowRealMatrix augI = new Array2DRowRealMatrix(n, n);
138 double[][] augIData = augI.getDataRef();
139 for (int i = 0; i < n; i++) {
140 for (int j =0; j < n; j++) {
141 if (i == j && i < p) {
142 augIData[i][j] = 1d;
143 } else {
144 augIData[i][j] = 0d;
145 }
146 }
147 }
148
149 // Compute and return Hat matrix
150 // No DME advertised - args valid if we get here
151 return Q.multiply(augI).multiplyTransposed(Q);
152 }
153
154 /**
155 * <p>Returns the sum of squared deviations of Y from its mean.</p>
156 *
157 * <p>If the model has no intercept term, <code>0</code> is used for the
158 * mean of Y - i.e., what is returned is the sum of the squared Y values.</p>
159 *
160 * <p>The value returned by this method is the SSTO value used in
161 * the {@link #calculateRSquared() R-squared} computation.</p>
162 *
163 * @return SSTO - the total sum of squares
164 * @throws NullPointerException if the sample has not been set
165 * @see #isNoIntercept()
166 */
167 public double calculateTotalSumOfSquares() {
168 if (isNoIntercept()) {
169 return StatUtils.sumSq(getY().toArray());
170 } else {
171 return new SecondMoment().evaluate(getY().toArray());
172 }
173 }
174
175 /**
176 * Returns the sum of squared residuals.
177 *
178 * @return residual sum of squares
179 * @throws org.hipparchus.exception.MathIllegalArgumentException if the design matrix is singular
180 * @throws NullPointerException if the data for the model have not been loaded
181 */
182 public double calculateResidualSumOfSquares() {
183 final RealVector residuals = calculateResiduals();
184 // No advertised DME, args are valid
185 return residuals.dotProduct(residuals);
186 }
187
188 /**
189 * Returns the R-Squared statistic, defined by the formula \(R^2 = 1 - \frac{\mathrm{SSR}}{\mathrm{SSTO}}\)
190 * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals}
191 * and SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}
192 *
193 * <p>If there is no variance in y, i.e., SSTO = 0, NaN is returned.</p>
194 *
195 * @return R-square statistic
196 * @throws NullPointerException if the sample has not been set
197 * @throws org.hipparchus.exception.MathIllegalArgumentException if the design matrix is singular
198 */
199 public double calculateRSquared() {
200 return 1 - calculateResidualSumOfSquares() / calculateTotalSumOfSquares();
201 }
202
203 /**
204 * <p>Returns the adjusted R-squared statistic, defined by the formula
205 * \(R_\mathrm{adj}^2 = 1 - \frac{\mathrm{SSR} (n - 1)}{\mathrm{SSTO} (n - p)}\)
206 * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals},
207 * SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}, n is the number
208 * of observations and p is the number of parameters estimated (including the intercept).</p>
209 *
210 * <p>If the regression is estimated without an intercept term, what is returned is </p><pre>
211 * <code> 1 - (1 - {@link #calculateRSquared()}) * (n / (n - p)) </code>
212 * </pre>
213 *
214 * <p>If there is no variance in y, i.e., SSTO = 0, NaN is returned.</p>
215 *
216 * @return adjusted R-Squared statistic
217 * @throws NullPointerException if the sample has not been set
218 * @throws org.hipparchus.exception.MathIllegalArgumentException if the design matrix is singular
219 * @see #isNoIntercept()
220 */
221 public double calculateAdjustedRSquared() {
222 final double n = getX().getRowDimension();
223 if (isNoIntercept()) {
224 return 1 - (1 - calculateRSquared()) * (n / (n - getX().getColumnDimension()));
225 } else {
226 return 1 - (calculateResidualSumOfSquares() * (n - 1)) /
227 (calculateTotalSumOfSquares() * (n - getX().getColumnDimension()));
228 }
229 }
230
231 /**
232 * {@inheritDoc}
233 * <p>This implementation computes and caches the QR decomposition of the X matrix
234 * once it is successfully loaded.</p>
235 */
236 @Override
237 protected void newXSampleData(double[][] x) {
238 super.newXSampleData(x);
239 qr = new QRDecomposition(getX(), threshold);
240 }
241
242 /**
243 * Calculates the regression coefficients using OLS.
244 *
245 * <p>Data for the model must have been successfully loaded using one of
246 * the {@code newSampleData} methods before invoking this method; otherwise
247 * a {@code NullPointerException} will be thrown.</p>
248 *
249 * @return beta
250 * @throws org.hipparchus.exception.MathIllegalArgumentException if the design matrix is singular
251 * @throws NullPointerException if the data for the model have not been loaded
252 */
253 @Override
254 protected RealVector calculateBeta() {
255 return qr.getSolver().solve(getY());
256 }
257
258 /**
259 * <p>Calculates the variance-covariance matrix of the regression parameters.
260 * </p>
261 * <p>Var(b) = (X<sup>T</sup>X)<sup>-1</sup>
262 * </p>
263 * <p>Uses QR decomposition to reduce (X<sup>T</sup>X)<sup>-1</sup>
264 * to (R<sup>T</sup>R)<sup>-1</sup>, with only the top p rows of
265 * R included, where p = the length of the beta vector.</p>
266 *
267 * <p>Data for the model must have been successfully loaded using one of
268 * the {@code newSampleData} methods before invoking this method; otherwise
269 * a {@code NullPointerException} will be thrown.</p>
270 *
271 * @return The beta variance-covariance matrix
272 * @throws org.hipparchus.exception.MathIllegalArgumentException if the design matrix is singular
273 * @throws NullPointerException if the data for the model have not been loaded
274 */
275 @Override
276 protected RealMatrix calculateBetaVariance() {
277 int p = getX().getColumnDimension();
278 RealMatrix Raug = qr.getR().getSubMatrix(0, p - 1 , 0, p - 1);
279 RealMatrix Rinv = new LUDecomposition(Raug).getSolver().getInverse();
280 return Rinv.multiplyTransposed(Rinv);
281 }
282
283 }