View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  package org.hipparchus.stat.regression;
23  
24  
25  import org.hipparchus.UnitTestUtils;
26  import org.hipparchus.exception.MathIllegalArgumentException;
27  import org.hipparchus.exception.NullArgumentException;
28  import org.hipparchus.linear.Array2DRowRealMatrix;
29  import org.hipparchus.linear.DefaultRealMatrixChangingVisitor;
30  import org.hipparchus.linear.MatrixUtils;
31  import org.hipparchus.linear.RealMatrix;
32  import org.hipparchus.linear.RealVector;
33  import org.hipparchus.stat.StatUtils;
34  import org.junit.Assert;
35  import org.junit.Before;
36  import org.junit.Test;
37  
38  public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbstractTest {
39  
40      private double[] y;
41      private double[][] x;
42  
43      @Before
44      @Override
45      public void setUp(){
46          y = new double[]{11.0, 12.0, 13.0, 14.0, 15.0, 16.0};
47          x = new double[6][];
48          x[0] = new double[]{0, 0, 0, 0, 0};
49          x[1] = new double[]{2.0, 0, 0, 0, 0};
50          x[2] = new double[]{0, 3.0, 0, 0, 0};
51          x[3] = new double[]{0, 0, 4.0, 0, 0};
52          x[4] = new double[]{0, 0, 0, 5.0, 0};
53          x[5] = new double[]{0, 0, 0, 0, 6.0};
54          super.setUp();
55      }
56  
57      @Override
58      protected OLSMultipleLinearRegression createRegression() {
59          OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
60          regression.newSampleData(y, x);
61          return regression;
62      }
63  
64      @Override
65      protected int getNumberOfRegressors() {
66          return x[0].length + 1;
67      }
68  
69      @Override
70      protected int getSampleSize() {
71          return y.length;
72      }
73  
74      @Test(expected=MathIllegalArgumentException.class)
75      public void cannotAddSampleDataWithSizeMismatch() {
76          double[] y = new double[]{1.0, 2.0};
77          double[][] x = new double[1][];
78          x[0] = new double[]{1.0, 0};
79          createRegression().newSampleData(y, x);
80      }
81  
82      @Test
83      public void testPerfectFit() {
84          double[] betaHat = regression.estimateRegressionParameters();
85          UnitTestUtils.assertEquals(betaHat,
86                                 new double[]{ 11.0, 1.0 / 2.0, 2.0 / 3.0, 3.0 / 4.0, 4.0 / 5.0, 5.0 / 6.0 },
87                                 1e-14);
88          double[] residuals = regression.estimateResiduals();
89          UnitTestUtils.assertEquals(residuals, new double[]{0d,0d,0d,0d,0d,0d},
90                                 1e-14);
91          RealMatrix errors =
92              new Array2DRowRealMatrix(regression.estimateRegressionParametersVariance(), false);
93          final double[] s = { 1.0, -1.0 /  2.0, -1.0 /  3.0, -1.0 /  4.0, -1.0 /  5.0, -1.0 /  6.0 };
94          RealMatrix referenceVariance = new Array2DRowRealMatrix(s.length, s.length);
95          referenceVariance.walkInOptimizedOrder(new DefaultRealMatrixChangingVisitor() {
96              @Override
97              public double visit(int row, int column, double value) {
98                  if (row == 0) {
99                      return s[column];
100                 }
101                 double x = s[row] * s[column];
102                 return (row == column) ? 2 * x : x;
103             }
104         });
105        Assert.assertEquals(0.0,
106                      errors.subtract(referenceVariance).getNorm1(),
107                      5.0e-16 * referenceVariance.getNorm1());
108        Assert.assertEquals(1, ((OLSMultipleLinearRegression) regression).calculateRSquared(), 1E-12);
109     }
110 
111 
112     /**
113      * Test Longley dataset against certified values provided by NIST.
114      * Data Source: J. Longley (1967) "An Appraisal of Least Squares
115      * Programs for the Electronic Computer from the Point of View of the User"
116      * Journal of the American Statistical Association, vol. 62. September,
117      * pp. 819-841.
118      *
119      * Certified values (and data) are from NIST:
120      * <a href="https://www.itl.nist.gov/div898/strd/lls/data/LINKS/DATA/Longley.dat">Longley dataset</a>
121      */
122     @Test
123     public void testLongly() {
124         // Y values are first, then independent vars
125         // Each row is one observation
126         double[] design = new double[] {
127             60323,83.0,234289,2356,1590,107608,1947,
128             61122,88.5,259426,2325,1456,108632,1948,
129             60171,88.2,258054,3682,1616,109773,1949,
130             61187,89.5,284599,3351,1650,110929,1950,
131             63221,96.2,328975,2099,3099,112075,1951,
132             63639,98.1,346999,1932,3594,113270,1952,
133             64989,99.0,365385,1870,3547,115094,1953,
134             63761,100.0,363112,3578,3350,116219,1954,
135             66019,101.2,397469,2904,3048,117388,1955,
136             67857,104.6,419180,2822,2857,118734,1956,
137             68169,108.4,442769,2936,2798,120445,1957,
138             66513,110.8,444546,4681,2637,121950,1958,
139             68655,112.6,482704,3813,2552,123366,1959,
140             69564,114.2,502601,3931,2514,125368,1960,
141             69331,115.7,518173,4806,2572,127852,1961,
142             70551,116.9,554894,4007,2827,130081,1962
143         };
144 
145         final int nobs = 16;
146         final int nvars = 6;
147 
148         // Estimate the model
149         OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
150         model.newSampleData(design, nobs, nvars);
151 
152         // Check expected beta values from NIST
153         double[] betaHat = model.estimateRegressionParameters();
154         UnitTestUtils.assertEquals(betaHat,
155           new double[]{-3482258.63459582, 15.0618722713733,
156                 -0.358191792925910E-01,-2.02022980381683,
157                 -1.03322686717359,-0.511041056535807E-01,
158                  1829.15146461355}, 2E-8); //
159 
160         // Check expected residuals from R
161         double[] residuals = model.estimateResiduals();
162         UnitTestUtils.assertEquals(residuals, new double[]{
163                 267.340029759711,-94.0139423988359,46.28716775752924,
164                 -410.114621930906,309.7145907602313,-249.3112153297231,
165                 -164.0489563956039,-13.18035686637081,14.30477260005235,
166                  455.394094551857,-17.26892711483297,-39.0550425226967,
167                 -155.5499735953195,-85.6713080421283,341.9315139607727,
168                 -206.7578251937366},
169                       1E-8);
170 
171         // Check standard errors from NIST
172         double[] errors = model.estimateRegressionParametersStandardErrors();
173         UnitTestUtils.assertEquals(new double[] {890420.383607373,
174                        84.9149257747669,
175                        0.334910077722432E-01,
176                        0.488399681651699,
177                        0.214274163161675,
178                        0.226073200069370,
179                        455.478499142212}, errors, 1E-6);
180 
181         // Check regression standard error against R
182         Assert.assertEquals(304.8540735619638, model.estimateRegressionStandardError(), 1E-10);
183 
184         // Check R-Square statistics against R
185         Assert.assertEquals(0.995479004577296, model.calculateRSquared(), 1E-12);
186         Assert.assertEquals(0.992465007628826, model.calculateAdjustedRSquared(), 1E-12);
187 
188         checkVarianceConsistency(model);
189 
190         // Estimate model without intercept
191         model.setNoIntercept(true);
192         model.newSampleData(design, nobs, nvars);
193 
194         // Check expected beta values from R
195         betaHat = model.estimateRegressionParameters();
196         UnitTestUtils.assertEquals(betaHat,
197           new double[]{-52.99357013868291, 0.07107319907358,
198                 -0.42346585566399,-0.57256866841929,
199                 -0.41420358884978, 48.41786562001326}, 1E-11);
200 
201         // Check standard errors from R
202         errors = model.estimateRegressionParametersStandardErrors();
203         UnitTestUtils.assertEquals(new double[] {129.54486693117232, 0.03016640003786,
204                 0.41773654056612, 0.27899087467676, 0.32128496193363,
205                 17.68948737819961}, errors, 1E-11);
206 
207         // Check expected residuals from R
208         residuals = model.estimateResiduals();
209         UnitTestUtils.assertEquals(residuals, new double[]{
210                 279.90274927293092, -130.32465380836874, 90.73228661967445, -401.31252201634948,
211                 -440.46768772620027, -543.54512853774793, 201.32111639536299, 215.90889365977932,
212                 73.09368242049943, 913.21694494481869, 424.82484953610174, -8.56475876776709,
213                 -361.32974610842876, 27.34560497213464, 151.28955976355002, -492.49937355336846},
214                       1E-10);
215 
216         // Check regression standard error against R
217         Assert.assertEquals(475.1655079819517, model.estimateRegressionStandardError(), 1E-10);
218 
219         // Check R-Square statistics against R
220         Assert.assertEquals(0.9999670130706, model.calculateRSquared(), 1E-12);
221         Assert.assertEquals(0.999947220913, model.calculateAdjustedRSquared(), 1E-12);
222 
223     }
224 
225     /**
226      * Test R Swiss fertility dataset against R.
227      * Data Source: R datasets package
228      */
229     @Test
230     public void testSwissFertility() {
231         double[] design = new double[] {
232             80.2,17.0,15,12,9.96,
233             83.1,45.1,6,9,84.84,
234             92.5,39.7,5,5,93.40,
235             85.8,36.5,12,7,33.77,
236             76.9,43.5,17,15,5.16,
237             76.1,35.3,9,7,90.57,
238             83.8,70.2,16,7,92.85,
239             92.4,67.8,14,8,97.16,
240             82.4,53.3,12,7,97.67,
241             82.9,45.2,16,13,91.38,
242             87.1,64.5,14,6,98.61,
243             64.1,62.0,21,12,8.52,
244             66.9,67.5,14,7,2.27,
245             68.9,60.7,19,12,4.43,
246             61.7,69.3,22,5,2.82,
247             68.3,72.6,18,2,24.20,
248             71.7,34.0,17,8,3.30,
249             55.7,19.4,26,28,12.11,
250             54.3,15.2,31,20,2.15,
251             65.1,73.0,19,9,2.84,
252             65.5,59.8,22,10,5.23,
253             65.0,55.1,14,3,4.52,
254             56.6,50.9,22,12,15.14,
255             57.4,54.1,20,6,4.20,
256             72.5,71.2,12,1,2.40,
257             74.2,58.1,14,8,5.23,
258             72.0,63.5,6,3,2.56,
259             60.5,60.8,16,10,7.72,
260             58.3,26.8,25,19,18.46,
261             65.4,49.5,15,8,6.10,
262             75.5,85.9,3,2,99.71,
263             69.3,84.9,7,6,99.68,
264             77.3,89.7,5,2,100.00,
265             70.5,78.2,12,6,98.96,
266             79.4,64.9,7,3,98.22,
267             65.0,75.9,9,9,99.06,
268             92.2,84.6,3,3,99.46,
269             79.3,63.1,13,13,96.83,
270             70.4,38.4,26,12,5.62,
271             65.7,7.7,29,11,13.79,
272             72.7,16.7,22,13,11.22,
273             64.4,17.6,35,32,16.92,
274             77.6,37.6,15,7,4.97,
275             67.6,18.7,25,7,8.65,
276             35.0,1.2,37,53,42.34,
277             44.7,46.6,16,29,50.43,
278             42.8,27.7,22,29,58.33
279         };
280 
281         final int nobs = 47;
282         final int nvars = 4;
283 
284         // Estimate the model
285         OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
286         model.newSampleData(design, nobs, nvars);
287 
288         // Check expected beta values from R
289         double[] betaHat = model.estimateRegressionParameters();
290         UnitTestUtils.assertEquals(betaHat,
291                 new double[]{91.05542390271397,
292                 -0.22064551045715,
293                 -0.26058239824328,
294                 -0.96161238456030,
295                  0.12441843147162}, 1E-12);
296 
297         // Check expected residuals from R
298         double[] residuals = model.estimateResiduals();
299         UnitTestUtils.assertEquals(residuals, new double[]{
300                 7.1044267859730512,1.6580347433531366,
301                 4.6944952770029644,8.4548022690166160,13.6547432343186212,
302                -9.3586864458500774,7.5822446330520386,15.5568995563859289,
303                 0.8113090736598980,7.1186762732484308,7.4251378771228724,
304                 2.6761316873234109,0.8351584810309354,7.1769991119615177,
305                -3.8746753206299553,-3.1337779476387251,-0.1412575244091504,
306                 1.1186809170469780,-6.3588097346816594,3.4039270429434074,
307                 2.3374058329820175,-7.9272368576900503,-7.8361010968497959,
308                -11.2597369269357070,0.9445333697827101,6.6544245101380328,
309                -0.9146136301118665,-4.3152449403848570,-4.3536932047009183,
310                -3.8907885169304661,-6.3027643926302188,-7.8308982189289091,
311                -3.1792280015332750,-6.7167298771158226,-4.8469946718041754,
312                -10.6335664353633685,11.1031134362036958,6.0084032641811733,
313                 5.4326230830188482,-7.2375578629692230,2.1671550814448222,
314                 15.0147574652763112,4.8625103516321015,-7.1597256413907706,
315                 -0.4515205619767598,-10.2916870903837587,-15.7812984571900063},
316                 1E-12);
317 
318         // Check standard errors from R
319         double[] errors = model.estimateRegressionParametersStandardErrors();
320         UnitTestUtils.assertEquals(new double[] {6.94881329475087,
321                 0.07360008972340,
322                 0.27410957467466,
323                 0.19454551679325,
324                 0.03726654773803}, errors, 1E-10);
325 
326         // Check regression standard error against R
327         Assert.assertEquals(7.73642194433223, model.estimateRegressionStandardError(), 1E-12);
328 
329         // Check R-Square statistics against R
330         Assert.assertEquals(0.649789742860228, model.calculateRSquared(), 1E-12);
331         Assert.assertEquals(0.6164363850373927, model.calculateAdjustedRSquared(), 1E-12);
332 
333         checkVarianceConsistency(model);
334 
335         // Estimate the model with no intercept
336         model = new OLSMultipleLinearRegression();
337         model.setNoIntercept(true);
338         model.newSampleData(design, nobs, nvars);
339 
340         // Check expected beta values from R
341         betaHat = model.estimateRegressionParameters();
342         UnitTestUtils.assertEquals(betaHat,
343                 new double[]{0.52191832900513,
344                   2.36588087917963,
345                   -0.94770353802795,
346                   0.30851985863609}, 1E-12);
347 
348         // Check expected residuals from R
349         residuals = model.estimateResiduals();
350         UnitTestUtils.assertEquals(residuals, new double[]{
351                 44.138759883538249, 27.720705122356215, 35.873200836126799,
352                 34.574619581211977, 26.600168342080213, 15.074636243026923, -12.704904871199814,
353                 1.497443824078134, 2.691972687079431, 5.582798774291231, -4.422986561283165,
354                 -9.198581600334345, 4.481765170730647, 2.273520207553216, -22.649827853221336,
355                 -17.747900013943308, 20.298314638496436, 6.861405135329779, -8.684712790954924,
356                 -10.298639278062371, -9.896618896845819, 4.568568616351242, -15.313570491727944,
357                 -13.762961360873966, 7.156100301980509, 16.722282219843990, 26.716200609071898,
358                 -1.991466398777079, -2.523342564719335, 9.776486693095093, -5.297535127628603,
359                 -16.639070567471094, -10.302057295211819, -23.549487860816846, 1.506624392156384,
360                 -17.939174438345930, 13.105792202765040, -1.943329906928462, -1.516005841666695,
361                 -0.759066561832886, 20.793137744128977, -2.485236153005426, 27.588238710486976,
362                 2.658333257106881, -15.998337823623046, -5.550742066720694, -14.219077806826615},
363                 1E-12);
364 
365         // Check standard errors from R
366         errors = model.estimateRegressionParametersStandardErrors();
367         UnitTestUtils.assertEquals(new double[] {0.10470063765677, 0.41684100584290,
368                 0.43370143099691, 0.07694953606522}, errors, 1E-10);
369 
370         // Check regression standard error against R
371         Assert.assertEquals(17.24710630547, model.estimateRegressionStandardError(), 1E-10);
372 
373         // Check R-Square statistics against R
374         Assert.assertEquals(0.946350722085, model.calculateRSquared(), 1E-12);
375         Assert.assertEquals(0.9413600915813, model.calculateAdjustedRSquared(), 1E-12);
376     }
377 
378     /**
379      * Test hat matrix computation
380      *
381      */
382     @Test
383     public void testHat() {
384 
385         /*
386          * This example is from "The Hat Matrix in Regression and ANOVA",
387          * David C. Hoaglin and Roy E. Welsch,
388          * The American Statistician, Vol. 32, No. 1 (Feb., 1978), pp. 17-22.
389          *
390          */
391         double[] design = new double[] {
392                 11.14, .499, 11.1,
393                 12.74, .558, 8.9,
394                 13.13, .604, 8.8,
395                 11.51, .441, 8.9,
396                 12.38, .550, 8.8,
397                 12.60, .528, 9.9,
398                 11.13, .418, 10.7,
399                 11.7, .480, 10.5,
400                 11.02, .406, 10.5,
401                 11.41, .467, 10.7
402         };
403 
404         int nobs = 10;
405         int nvars = 2;
406 
407         // Estimate the model
408         OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
409         model.newSampleData(design, nobs, nvars);
410 
411         RealMatrix hat = model.calculateHat();
412 
413         // Reference data is upper half of symmetric hat matrix
414         double[] referenceData = new double[] {
415                 .418, -.002,  .079, -.274, -.046,  .181,  .128,  .222,  .050,  .242,
416                        .242,  .292,  .136,  .243,  .128, -.041,  .033, -.035,  .004,
417                               .417, -.019,  .273,  .187, -.126,  .044, -.153,  .004,
418                                      .604,  .197, -.038,  .168, -.022,  .275, -.028,
419                                             .252,  .111, -.030,  .019, -.010, -.010,
420                                                    .148,  .042,  .117,  .012,  .111,
421                                                           .262,  .145,  .277,  .174,
422                                                                  .154,  .120,  .168,
423                                                                         .315,  .148,
424                                                                                .187
425         };
426 
427         // Check against reference data and verify symmetry
428         int k = 0;
429         for (int i = 0; i < 10; i++) {
430             for (int j = i; j < 10; j++) {
431                 Assert.assertEquals(referenceData[k], hat.getEntry(i, j), 10e-3);
432                 Assert.assertEquals(hat.getEntry(i, j), hat.getEntry(j, i), 10e-12);
433                 k++;
434             }
435         }
436 
437         /*
438          * Verify that residuals computed using the hat matrix are close to
439          * what we get from direct computation, i.e. r = (I - H) y
440          */
441         double[] residuals = model.estimateResiduals();
442         RealMatrix I = MatrixUtils.createRealIdentityMatrix(10);
443         double[] hatResiduals = I.subtract(hat).operate(model.getY()).toArray();
444         UnitTestUtils.assertEquals(residuals, hatResiduals, 10e-12);
445     }
446 
447     /**
448      * test calculateYVariance
449      */
450     @Test
451     public void testYVariance() {
452 
453         // assumes: y = new double[]{11.0, 12.0, 13.0, 14.0, 15.0, 16.0};
454 
455         OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
456         model.newSampleData(y, x);
457         UnitTestUtils.assertEquals(model.calculateYVariance(), 3.5, 0);
458     }
459 
460     /**
461      * Verifies that calculateYVariance and calculateResidualVariance return consistent
462      * values with direct variance computation from Y, residuals, respectively.
463      */
464     protected void checkVarianceConsistency(OLSMultipleLinearRegression model) {
465         // Check Y variance consistency
466         UnitTestUtils.assertEquals(StatUtils.variance(model.getY().toArray()), model.calculateYVariance(), 0);
467 
468         // Check residual variance consistency
469         double[] residuals = model.calculateResiduals().toArray();
470         RealMatrix X = model.getX();
471         UnitTestUtils.assertEquals(
472                 StatUtils.variance(model.calculateResiduals().toArray()) * (residuals.length - 1),
473                 model.calculateErrorVariance() * (X.getRowDimension() - X.getColumnDimension()), 1E-20);
474 
475     }
476 
477     /**
478      * Verifies that setting X and Y separately has the same effect as newSample(X,Y).
479      */
480     @Test
481     public void testNewSample2() {
482         double[] y = new double[] {1, 2, 3, 4};
483         double[][] x = new double[][] {
484           {19, 22, 33},
485           {20, 30, 40},
486           {25, 35, 45},
487           {27, 37, 47}
488         };
489         OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
490         regression.newSampleData(y, x);
491         RealMatrix combinedX = regression.getX().copy();
492         RealVector combinedY = regression.getY().copy();
493         regression.newXSampleData(x);
494         regression.newYSampleData(y);
495         Assert.assertEquals(combinedX, regression.getX());
496         Assert.assertEquals(combinedY, regression.getY());
497 
498         // No intercept
499         regression.setNoIntercept(true);
500         regression.newSampleData(y, x);
501         combinedX = regression.getX().copy();
502         combinedY = regression.getY().copy();
503         regression.newXSampleData(x);
504         regression.newYSampleData(y);
505         Assert.assertEquals(combinedX, regression.getX());
506         Assert.assertEquals(combinedY, regression.getY());
507     }
508 
509     @Test(expected=NullArgumentException.class)
510     public void testNewSampleDataYNull() {
511         createRegression().newSampleData(null, new double[][] {});
512     }
513 
514     @Test(expected=NullArgumentException.class)
515     public void testNewSampleDataXNull() {
516         createRegression().newSampleData(new double[] {}, null);
517     }
518 
519      /*
520      * This is a test based on the Wampler1 data set
521      * http://www.itl.nist.gov/div898/strd/lls/data/Wampler1.shtml
522      */
523     @Test
524     public void testWampler1() {
525         double[] data = new double[]{
526             1, 0,
527             6, 1,
528             63, 2,
529             364, 3,
530             1365, 4,
531             3906, 5,
532             9331, 6,
533             19608, 7,
534             37449, 8,
535             66430, 9,
536             111111, 10,
537             177156, 11,
538             271453, 12,
539             402234, 13,
540             579195, 14,
541             813616, 15,
542             1118481, 16,
543             1508598, 17,
544             2000719, 18,
545             2613660, 19,
546             3368421, 20};
547         OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
548 
549 
550         final int nvars = 5;
551         final int nobs = 21;
552         double[] tmp = new double[(nvars + 1) * nobs];
553         int off = 0;
554         int off2 = 0;
555         for (int i = 0; i < nobs; i++) {
556             tmp[off2] = data[off];
557             tmp[off2 + 1] = data[off + 1];
558             tmp[off2 + 2] = tmp[off2 + 1] * tmp[off2 + 1];
559             tmp[off2 + 3] = tmp[off2 + 1] * tmp[off2 + 2];
560             tmp[off2 + 4] = tmp[off2 + 1] * tmp[off2 + 3];
561             tmp[off2 + 5] = tmp[off2 + 1] * tmp[off2 + 4];
562             off2 += (nvars + 1);
563             off += 2;
564         }
565         model.newSampleData(tmp, nobs, nvars);
566         double[] betaHat = model.estimateRegressionParameters();
567         UnitTestUtils.assertEquals(betaHat,
568                 new double[]{1.0,
569                     1.0, 1.0,
570                     1.0, 1.0,
571                     1.0}, 1E-8);
572 
573         double[] se = model.estimateRegressionParametersStandardErrors();
574         UnitTestUtils.assertEquals(se,
575                 new double[]{0.0,
576                     0.0, 0.0,
577                     0.0, 0.0,
578                     0.0}, 1E-8);
579 
580         UnitTestUtils.assertEquals(1.0, model.calculateRSquared(), 1.0e-10);
581         UnitTestUtils.assertEquals(0, model.estimateErrorVariance(), 1.0e-7);
582         UnitTestUtils.assertEquals(0.00, model.calculateResidualSumOfSquares(), 1.0e-6);
583 
584         return;
585     }
586 
587     /*
588      * This is a test based on the Wampler2 data set
589      * http://www.itl.nist.gov/div898/strd/lls/data/Wampler2.shtml
590      */
591     @Test
592     public void testWampler2() {
593         double[] data = new double[]{
594             1.00000, 0,
595             1.11111, 1,
596             1.24992, 2,
597             1.42753, 3,
598             1.65984, 4,
599             1.96875, 5,
600             2.38336, 6,
601             2.94117, 7,
602             3.68928, 8,
603             4.68559, 9,
604             6.00000, 10,
605             7.71561, 11,
606             9.92992, 12,
607             12.75603, 13,
608             16.32384, 14,
609             20.78125, 15,
610             26.29536, 16,
611             33.05367, 17,
612             41.26528, 18,
613             51.16209, 19,
614             63.00000, 20};
615         OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
616 
617 
618         final int nvars = 5;
619         final int nobs = 21;
620         double[] tmp = new double[(nvars + 1) * nobs];
621         int off = 0;
622         int off2 = 0;
623         for (int i = 0; i < nobs; i++) {
624             tmp[off2] = data[off];
625             tmp[off2 + 1] = data[off + 1];
626             tmp[off2 + 2] = tmp[off2 + 1] * tmp[off2 + 1];
627             tmp[off2 + 3] = tmp[off2 + 1] * tmp[off2 + 2];
628             tmp[off2 + 4] = tmp[off2 + 1] * tmp[off2 + 3];
629             tmp[off2 + 5] = tmp[off2 + 1] * tmp[off2 + 4];
630             off2 += (nvars + 1);
631             off += 2;
632         }
633         model.newSampleData(tmp, nobs, nvars);
634         double[] betaHat = model.estimateRegressionParameters();
635         UnitTestUtils.assertEquals(betaHat,
636                 new double[]{
637                     1.0,
638                     1.0e-1,
639                     1.0e-2,
640                     1.0e-3, 1.0e-4,
641                     1.0e-5}, 1E-8);
642 
643         double[] se = model.estimateRegressionParametersStandardErrors();
644         UnitTestUtils.assertEquals(se,
645                 new double[]{0.0,
646                     0.0, 0.0,
647                     0.0, 0.0,
648                     0.0}, 1E-8);
649         UnitTestUtils.assertEquals(1.0, model.calculateRSquared(), 1.0e-10);
650         UnitTestUtils.assertEquals(0, model.estimateErrorVariance(), 1.0e-7);
651         UnitTestUtils.assertEquals(0.00, model.calculateResidualSumOfSquares(), 1.0e-6);
652         return;
653     }
654 
655     /*
656      * This is a test based on the Wampler3 data set
657      * http://www.itl.nist.gov/div898/strd/lls/data/Wampler3.shtml
658      */
659     @Test
660     public void testWampler3() {
661         double[] data = new double[]{
662             760, 0,
663             -2042, 1,
664             2111, 2,
665             -1684, 3,
666             3888, 4,
667             1858, 5,
668             11379, 6,
669             17560, 7,
670             39287, 8,
671             64382, 9,
672             113159, 10,
673             175108, 11,
674             273291, 12,
675             400186, 13,
676             581243, 14,
677             811568, 15,
678             1121004, 16,
679             1506550, 17,
680             2002767, 18,
681             2611612, 19,
682             3369180, 20};
683 
684         OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
685         final int nvars = 5;
686         final int nobs = 21;
687         double[] tmp = new double[(nvars + 1) * nobs];
688         int off = 0;
689         int off2 = 0;
690         for (int i = 0; i < nobs; i++) {
691             tmp[off2] = data[off];
692             tmp[off2 + 1] = data[off + 1];
693             tmp[off2 + 2] = tmp[off2 + 1] * tmp[off2 + 1];
694             tmp[off2 + 3] = tmp[off2 + 1] * tmp[off2 + 2];
695             tmp[off2 + 4] = tmp[off2 + 1] * tmp[off2 + 3];
696             tmp[off2 + 5] = tmp[off2 + 1] * tmp[off2 + 4];
697             off2 += (nvars + 1);
698             off += 2;
699         }
700         model.newSampleData(tmp, nobs, nvars);
701         double[] betaHat = model.estimateRegressionParameters();
702         UnitTestUtils.assertEquals(betaHat,
703                 new double[]{
704                     1.0,
705                     1.0,
706                     1.0,
707                     1.0,
708                     1.0,
709                     1.0}, 1E-8);
710 
711         double[] se = model.estimateRegressionParametersStandardErrors();
712         UnitTestUtils.assertEquals(se,
713                 new double[]{2152.32624678170,
714                     2363.55173469681, 779.343524331583,
715                     101.475507550350, 5.64566512170752,
716                     0.112324854679312}, 1E-8); //
717 
718         UnitTestUtils.assertEquals(.999995559025820, model.calculateRSquared(), 1.0e-10);
719         UnitTestUtils.assertEquals(5570284.53333333, model.estimateErrorVariance(), 1.0e-6);
720         UnitTestUtils.assertEquals(83554268.0000000, model.calculateResidualSumOfSquares(), 1.0e-5);
721         return;
722     }
723 
724     /*
725      * This is a test based on the Wampler4 data set
726      * http://www.itl.nist.gov/div898/strd/lls/data/Wampler4.shtml
727      */
728     @Test
729     public void testWampler4() {
730         double[] data = new double[]{
731             75901, 0,
732             -204794, 1,
733             204863, 2,
734             -204436, 3,
735             253665, 4,
736             -200894, 5,
737             214131, 6,
738             -185192, 7,
739             221249, 8,
740             -138370, 9,
741             315911, 10,
742             -27644, 11,
743             455253, 12,
744             197434, 13,
745             783995, 14,
746             608816, 15,
747             1370781, 16,
748             1303798, 17,
749             2205519, 18,
750             2408860, 19,
751             3444321, 20};
752 
753         OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
754         final int nvars = 5;
755         final int nobs = 21;
756         double[] tmp = new double[(nvars + 1) * nobs];
757         int off = 0;
758         int off2 = 0;
759         for (int i = 0; i < nobs; i++) {
760             tmp[off2] = data[off];
761             tmp[off2 + 1] = data[off + 1];
762             tmp[off2 + 2] = tmp[off2 + 1] * tmp[off2 + 1];
763             tmp[off2 + 3] = tmp[off2 + 1] * tmp[off2 + 2];
764             tmp[off2 + 4] = tmp[off2 + 1] * tmp[off2 + 3];
765             tmp[off2 + 5] = tmp[off2 + 1] * tmp[off2 + 4];
766             off2 += (nvars + 1);
767             off += 2;
768         }
769         model.newSampleData(tmp, nobs, nvars);
770         double[] betaHat = model.estimateRegressionParameters();
771         UnitTestUtils.assertEquals(betaHat,
772                 new double[]{
773                     1.0,
774                     1.0,
775                     1.0,
776                     1.0,
777                     1.0,
778                     1.0}, 1E-6);
779 
780         double[] se = model.estimateRegressionParametersStandardErrors();
781         UnitTestUtils.assertEquals(se,
782                 new double[]{215232.624678170,
783                     236355.173469681, 77934.3524331583,
784                     10147.5507550350, 564.566512170752,
785                     11.2324854679312}, 1E-8);
786 
787         UnitTestUtils.assertEquals(.957478440825662, model.calculateRSquared(), 1.0e-10);
788         UnitTestUtils.assertEquals(55702845333.3333, model.estimateErrorVariance(), 1.0e-4);
789         UnitTestUtils.assertEquals(835542680000.000, model.calculateResidualSumOfSquares(), 1.0e-3);
790         return;
791     }
792 
793     /**
794      * Anything requiring beta calculation should advertise SME.
795      */
796     @Test(expected=MathIllegalArgumentException.class)
797     public void testSingularCalculateBeta() {
798         OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
799         model.newSampleData(new double[] {1,  2,  3, 1, 2, 3, 1, 2, 3}, 3, 2);
800         model.calculateBeta();
801     }
802 
803     @Test
804     public void testNoSSTOCalculateRsquare() {
805         OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
806         model.newSampleData(new double[] {1,  2,  3, 1, 7, 8, 1, 10, 12}, 3, 2);
807         Assert.assertTrue(Double.isNaN(model.calculateRSquared()));
808     }
809 
810     @Test(expected=NullPointerException.class)
811     public void testNoDataNPECalculateBeta() {
812         OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
813         model.calculateBeta();
814     }
815 
816     @Test(expected=NullPointerException.class)
817     public void testNoDataNPECalculateHat() {
818         OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
819         model.calculateHat();
820     }
821 
822     @Test(expected=NullPointerException.class)
823     public void testNoDataNPESSTO() {
824         OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
825         model.calculateTotalSumOfSquares();
826     }
827 
828     /**
829      * From <a href="https://stackoverflow.com/questions/37320008/ols-multiple-linear-regression-with-commons-math">OLS
830      * Multiple Linear Regression with commons-math</a>
831      */
832     @Test
833     public void testNewSampleDataNoIntercept() {
834         final double[][] x = { { 1, 0, 0, 0 }, { 0, 1, 0, 0 }, { 0, 0, 1, 0}, {  0, 0, 0, 1 } };
835         final double[] y = { 1, 2, 3, 4 };
836 
837         final OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
838         regression.setNoIntercept(true);
839         regression.newSampleData(y, x);
840 
841         final double[] b = regression.estimateRegressionParameters();
842         for (int i = 0; i < y.length; i++) {
843             Assert.assertEquals(b[i], y[i], 0.001);
844         }
845     }
846 }