1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one or more 3 * contributor license agreements. See the NOTICE file distributed with 4 * this work for additional information regarding copyright ownership. 5 * The ASF licenses this file to You under the Apache License, Version 2.0 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * https://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 18 /* 19 * This is not the original file distributed by the Apache Software Foundation 20 * It has been modified by the Hipparchus project 21 */ 22 package org.hipparchus.stat.regression; 23 24 import org.hipparchus.exception.MathIllegalArgumentException; 25 import org.hipparchus.linear.Array2DRowRealMatrix; 26 import org.hipparchus.linear.LUDecomposition; 27 import org.hipparchus.linear.QRDecomposition; 28 import org.hipparchus.linear.RealMatrix; 29 import org.hipparchus.linear.RealVector; 30 import org.hipparchus.stat.StatUtils; 31 import org.hipparchus.stat.descriptive.moment.SecondMoment; 32 33 /** 34 * <p>Implements ordinary least squares (OLS) to estimate the parameters of a 35 * multiple linear regression model.</p> 36 * 37 * <p>The regression coefficients, <code>b</code>, satisfy the normal equations:</p> 38 * <pre><code> X<sup>T</sup> X b = X<sup>T</sup> y </code></pre> 39 * 40 * <p> 41 * To solve the normal equations, this implementation uses QR decomposition 42 * of the <code>X</code> matrix. (See {@link QRDecomposition} for details on the 43 * decomposition algorithm.) The <code>X</code> matrix, also known as the <i>design matrix,</i> 44 * has rows corresponding to sample observations and columns corresponding to independent 45 * variables. When the model is estimated using an intercept term (i.e. when 46 * {@link #isNoIntercept() isNoIntercept} is false as it is by default), the <code>X</code> 47 * matrix includes an initial column identically equal to 1. We solve the normal equations 48 * as follows: 49 * </p> 50 * <pre><code> X<sup>T</sup>X b = X<sup>T</sup> y 51 * (QR)<sup>T</sup> (QR) b = (QR)<sup>T</sup>y 52 * R<sup>T</sup> (Q<sup>T</sup>Q) R b = R<sup>T</sup> Q<sup>T</sup> y 53 * R<sup>T</sup> R b = R<sup>T</sup> Q<sup>T</sup> y 54 * (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> R b = (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> Q<sup>T</sup> y 55 * R b = Q<sup>T</sup> y </code></pre> 56 * 57 * <p>Given <code>Q</code> and <code>R</code>, the last equation is solved by back-substitution.</p> 58 * 59 */ 60 public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression { 61 62 /** Cached QR decomposition of X matrix */ 63 private QRDecomposition qr; 64 65 /** Singularity threshold for QR decomposition */ 66 private final double threshold; 67 68 /** 69 * Create an empty OLSMultipleLinearRegression instance. 70 */ 71 public OLSMultipleLinearRegression() { 72 this(0d); 73 } 74 75 /** 76 * Create an empty OLSMultipleLinearRegression instance, using the given 77 * singularity threshold for the QR decomposition. 78 * 79 * @param threshold the singularity threshold 80 */ 81 public OLSMultipleLinearRegression(final double threshold) { 82 this.threshold = threshold; 83 } 84 85 /** 86 * Loads model x and y sample data, overriding any previous sample. 87 * 88 * Computes and caches QR decomposition of the X matrix. 89 * @param y the [n,1] array representing the y sample 90 * @param x the [n,k] array representing the x sample 91 * @throws MathIllegalArgumentException if the x and y array data are not 92 * compatible for the regression 93 */ 94 public void newSampleData(double[] y, double[][] x) throws MathIllegalArgumentException { 95 validateSampleData(x, y); 96 newYSampleData(y); 97 newXSampleData(x); 98 } 99 100 /** 101 * {@inheritDoc} 102 * <p>This implementation computes and caches the QR decomposition of the X matrix.</p> 103 */ 104 @Override 105 public void newSampleData(double[] data, int nobs, int nvars) { 106 super.newSampleData(data, nobs, nvars); 107 qr = new QRDecomposition(getX(), threshold); 108 } 109 110 /** 111 * <p>Compute the "hat" matrix. 112 * </p> 113 * <p>The hat matrix is defined in terms of the design matrix X 114 * by X(X<sup>T</sup>X)<sup>-1</sup>X<sup>T</sup> 115 * </p> 116 * <p>The implementation here uses the QR decomposition to compute the 117 * hat matrix as Q I<sub>p</sub>Q<sup>T</sup> where I<sub>p</sub> is the 118 * p-dimensional identity matrix augmented by 0's. This computational 119 * formula is from "The Hat Matrix in Regression and ANOVA", 120 * David C. Hoaglin and Roy E. Welsch, 121 * <i>The American Statistician</i>, Vol. 32, No. 1 (Feb., 1978), pp. 17-22. 122 * </p> 123 * <p>Data for the model must have been successfully loaded using one of 124 * the {@code newSampleData} methods before invoking this method; otherwise 125 * a {@code NullPointerException} will be thrown.</p> 126 * 127 * @return the hat matrix 128 * @throws NullPointerException unless method {@code newSampleData} has been 129 * called beforehand. 130 */ 131 public RealMatrix calculateHat() { 132 // Create augmented identity matrix 133 RealMatrix Q = qr.getQ(); 134 final int p = qr.getR().getColumnDimension(); 135 final int n = Q.getColumnDimension(); 136 // No try-catch or advertised MathIllegalArgumentException - NPE above if n < 3 137 Array2DRowRealMatrix augI = new Array2DRowRealMatrix(n, n); 138 double[][] augIData = augI.getDataRef(); 139 for (int i = 0; i < n; i++) { 140 for (int j =0; j < n; j++) { 141 if (i == j && i < p) { 142 augIData[i][j] = 1d; 143 } else { 144 augIData[i][j] = 0d; 145 } 146 } 147 } 148 149 // Compute and return Hat matrix 150 // No DME advertised - args valid if we get here 151 return Q.multiply(augI).multiplyTransposed(Q); 152 } 153 154 /** 155 * <p>Returns the sum of squared deviations of Y from its mean.</p> 156 * 157 * <p>If the model has no intercept term, <code>0</code> is used for the 158 * mean of Y - i.e., what is returned is the sum of the squared Y values.</p> 159 * 160 * <p>The value returned by this method is the SSTO value used in 161 * the {@link #calculateRSquared() R-squared} computation.</p> 162 * 163 * @return SSTO - the total sum of squares 164 * @throws NullPointerException if the sample has not been set 165 * @see #isNoIntercept() 166 */ 167 public double calculateTotalSumOfSquares() { 168 if (isNoIntercept()) { 169 return StatUtils.sumSq(getY().toArray()); 170 } else { 171 return new SecondMoment().evaluate(getY().toArray()); 172 } 173 } 174 175 /** 176 * Returns the sum of squared residuals. 177 * 178 * @return residual sum of squares 179 * @throws org.hipparchus.exception.MathIllegalArgumentException if the design matrix is singular 180 * @throws NullPointerException if the data for the model have not been loaded 181 */ 182 public double calculateResidualSumOfSquares() { 183 final RealVector residuals = calculateResiduals(); 184 // No advertised DME, args are valid 185 return residuals.dotProduct(residuals); 186 } 187 188 /** 189 * Returns the R-Squared statistic, defined by the formula \(R^2 = 1 - \frac{\mathrm{SSR}}{\mathrm{SSTO}}\) 190 * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals} 191 * and SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares} 192 * 193 * <p>If there is no variance in y, i.e., SSTO = 0, NaN is returned.</p> 194 * 195 * @return R-square statistic 196 * @throws NullPointerException if the sample has not been set 197 * @throws org.hipparchus.exception.MathIllegalArgumentException if the design matrix is singular 198 */ 199 public double calculateRSquared() { 200 return 1 - calculateResidualSumOfSquares() / calculateTotalSumOfSquares(); 201 } 202 203 /** 204 * <p>Returns the adjusted R-squared statistic, defined by the formula 205 * \(R_\mathrm{adj}^2 = 1 - \frac{\mathrm{SSR} (n - 1)}{\mathrm{SSTO} (n - p)}\) 206 * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals}, 207 * SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}, n is the number 208 * of observations and p is the number of parameters estimated (including the intercept).</p> 209 * 210 * <p>If the regression is estimated without an intercept term, what is returned is </p><pre> 211 * <code> 1 - (1 - {@link #calculateRSquared()}) * (n / (n - p)) </code> 212 * </pre> 213 * 214 * <p>If there is no variance in y, i.e., SSTO = 0, NaN is returned.</p> 215 * 216 * @return adjusted R-Squared statistic 217 * @throws NullPointerException if the sample has not been set 218 * @throws org.hipparchus.exception.MathIllegalArgumentException if the design matrix is singular 219 * @see #isNoIntercept() 220 */ 221 public double calculateAdjustedRSquared() { 222 final double n = getX().getRowDimension(); 223 if (isNoIntercept()) { 224 return 1 - (1 - calculateRSquared()) * (n / (n - getX().getColumnDimension())); 225 } else { 226 return 1 - (calculateResidualSumOfSquares() * (n - 1)) / 227 (calculateTotalSumOfSquares() * (n - getX().getColumnDimension())); 228 } 229 } 230 231 /** 232 * {@inheritDoc} 233 * <p>This implementation computes and caches the QR decomposition of the X matrix 234 * once it is successfully loaded.</p> 235 */ 236 @Override 237 protected void newXSampleData(double[][] x) { 238 super.newXSampleData(x); 239 qr = new QRDecomposition(getX(), threshold); 240 } 241 242 /** 243 * Calculates the regression coefficients using OLS. 244 * 245 * <p>Data for the model must have been successfully loaded using one of 246 * the {@code newSampleData} methods before invoking this method; otherwise 247 * a {@code NullPointerException} will be thrown.</p> 248 * 249 * @return beta 250 * @throws org.hipparchus.exception.MathIllegalArgumentException if the design matrix is singular 251 * @throws NullPointerException if the data for the model have not been loaded 252 */ 253 @Override 254 protected RealVector calculateBeta() { 255 return qr.getSolver().solve(getY()); 256 } 257 258 /** 259 * <p>Calculates the variance-covariance matrix of the regression parameters. 260 * </p> 261 * <p>Var(b) = (X<sup>T</sup>X)<sup>-1</sup> 262 * </p> 263 * <p>Uses QR decomposition to reduce (X<sup>T</sup>X)<sup>-1</sup> 264 * to (R<sup>T</sup>R)<sup>-1</sup>, with only the top p rows of 265 * R included, where p = the length of the beta vector.</p> 266 * 267 * <p>Data for the model must have been successfully loaded using one of 268 * the {@code newSampleData} methods before invoking this method; otherwise 269 * a {@code NullPointerException} will be thrown.</p> 270 * 271 * @return The beta variance-covariance matrix 272 * @throws org.hipparchus.exception.MathIllegalArgumentException if the design matrix is singular 273 * @throws NullPointerException if the data for the model have not been loaded 274 */ 275 @Override 276 protected RealMatrix calculateBetaVariance() { 277 int p = getX().getColumnDimension(); 278 RealMatrix Raug = qr.getR().getSubMatrix(0, p - 1 , 0, p - 1); 279 RealMatrix Rinv = new LUDecomposition(Raug).getSolver().getInverse(); 280 return Rinv.multiplyTransposed(Rinv); 281 } 282 283 }