View Javadoc
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with this
 * work for additional information regarding copyright ownership. The ASF
 * licenses this file to You under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * https://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law
 * or agreed to in writing, software distributed under the License is
 * distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the specific language
 * governing permissions and limitations under the License.
 */
package org.hipparchus.optim.nonlinear.vector.leastsquares;

import java.awt.geom.Point2D;
import java.util.ArrayList;
import java.util.List;

import org.hipparchus.UnitTestUtils;
import org.hipparchus.linear.ArrayRealVector;
import org.hipparchus.linear.DiagonalMatrix;
import org.hipparchus.linear.RealVector;
import org.hipparchus.util.FastMath;
import org.junit.Assert;
import org.junit.Ignore;
import org.junit.Test;
/**
 * This class demonstrates the main functionality of the
 * {@link LeastSquaresProblem.Evaluation}, common to the
 * optimizer implementations in package
 * {@link org.hipparchus.optim.nonlinear.vector.leastsquares}.
 * <br>
 * Not enabled by default, as the class name does not end with "Test".
 * <br>
 * Invoke by running
 * <pre><code>
 *  mvn test -Dtest=EvaluationTestValidation
 * </code></pre>
 * or by running
 * <pre><code>
 *  mvn test -Dtest=EvaluationTestValidation -DargLine="-DmcRuns=1234 -server"
 * </code></pre>
 */
46  public class EvaluationTestValidation {
47      /** Number of runs. */
48      private static final int MONTE_CARLO_RUNS = Integer.parseInt(System.getProperty("mcRuns",
49                                                                                      "100"));
50  
51      /**
52       * Using a Monte-Carlo procedure, this test checks the error estimations
53       * as provided by the square-root of the diagonal elements of the
54       * covariance matrix.
55       * <br>
56       * The test generates sets of observations, each sampled from
57       * a Gaussian distribution.
58       * <br>
59       * The optimization problem solved is defined in class
60       * {@link StraightLineProblem}.
61       * <br>
62       * The output (on stdout) will be a table summarizing the distribution
63       * of parameters generated by the Monte-Carlo process and by the direct
64       * estimation provided by the diagonal elements of the covariance matrix.
65       */
66      @Ignore
67      @Test
68      public void testParametersErrorMonteCarloObservations() {
69          // Error on the observations.
70          final double yError = 15;
71  
72          // True values of the parameters.
73          final double slope = 123.456;
74          final double offset = -98.765;
75  
76          // Samples generator.
77          final RandomStraightLinePointGenerator lineGenerator
78              = new RandomStraightLinePointGenerator(slope, offset,
79                                                     yError,
80                                                     -1e3, 1e4,
81                                                     138577L);
82  
83          // Number of observations.
84          final int numObs = 100; // XXX Should be a command-line option.
85          // number of parameters.
86          final int numParams = 2;
87  
88          // Parameters found for each of Monte-Carlo run.
89          final UnitTestUtils.SimpleStatistics[] paramsFoundByDirectSolution = new UnitTestUtils.SimpleStatistics[numParams];
90          // Sigma estimations (square-root of the diagonal elements of the
91          // covariance matrix), for each Monte-Carlo run.
92          final UnitTestUtils.SimpleStatistics[] sigmaEstimate = new UnitTestUtils.SimpleStatistics[numParams];
93  
94          // Initialize statistics accumulators.
95          for (int i = 0; i < numParams; i++) {
96              paramsFoundByDirectSolution[i] = new UnitTestUtils.SimpleStatistics();
97              sigmaEstimate[i] = new UnitTestUtils.SimpleStatistics();
98          }
99  
100         final RealVector init = new ArrayRealVector(new double[]{ slope, offset }, false);
101 
102         // Monte-Carlo (generates many sets of observations).
103         final int mcRepeat = MONTE_CARLO_RUNS;
104         int mcCount = 0;
105         while (mcCount < mcRepeat) {
106             // Observations.
107             final Point2D.Double[] obs = lineGenerator.generate(numObs);
108 
109             final StraightLineProblem problem = new StraightLineProblem(yError);
110             for (int i = 0; i < numObs; i++) {
111                 final Point2D.Double p = obs[i];
112                 problem.addPoint(p.x, p.y);
113             }
114 
115             // Direct solution (using simple regression).
116             final double[] regress = problem.solve();
117 
118             // Estimation of the standard deviation (diagonal elements of the
119             // covariance matrix).
120             final LeastSquaresProblem lsp = builder(problem).build();
121 
122             final RealVector sigma = lsp.evaluate(init).getSigma(1e-14);
123 
124             // Accumulate statistics.
125             for (int i = 0; i < numParams; i++) {
126                 paramsFoundByDirectSolution[i].addValue(regress[i]);
127                 sigmaEstimate[i].addValue(sigma.getEntry(i));
128             }
129 
130             // Next Monte-Carlo.
131             ++mcCount;
132         }
133 
134         // Print statistics.
135         final String line = "--------------------------------------------------------------";
136         System.out.println("                 True value       Mean        Std deviation");
137         for (int i = 0; i < numParams; i++) {
138             System.out.println(line);
139             System.out.println("Parameter #" + i);
140 
141             System.out.printf("              %+.6e   %+.6e   %+.6e\n",
142                               init.getEntry(i),
143                               paramsFoundByDirectSolution[i].getMean(),
144                               paramsFoundByDirectSolution[i].getStandardDeviation());
145 
146             System.out.printf("sigma: %+.6e (%+.6e)\n",
147                               sigmaEstimate[i].getMean(),
148                               sigmaEstimate[i].getStandardDeviation());
149         }
150         System.out.println(line);
151 
152         // Check the error estimation.
153         for (int i = 0; i < numParams; i++) {
154             Assert.assertEquals(paramsFoundByDirectSolution[i].getStandardDeviation(),
155                                 sigmaEstimate[i].getMean(),
156                                 8e-2);
157         }
158     }
159 
160     /**
161      * In this test, the set of observations is fixed.
162      * Using a Monte-Carlo procedure, it generates sets of parameters,
163      * and determine the parameter change that will result in the
164      * normalized chi-square becoming larger by one than the value from
165      * the best fit solution.
166      * <br>
167      * The optimization problem solved is defined in class
168      * {@link StraightLineProblem}.
169      * <br>
170      * The output (on stdout) will be a list of lines containing:
171      * <ul>
172      *  <li>slope of the straight line,</li>
173      *  <li>intercept of the straight line,</li>
174      *  <li>chi-square of the solution defined by the above two values.</li>
175      * </ul>
176      * The output is separated into two blocks (with a blank line between
177      * them); the first block will contain all parameter sets for which
178      * {@code chi2 < chi2_b + 1}
179      * and the second block, all sets for which
180      * {@code chi2 >= chi2_b + 1}
181      * where {@code chi2_b} is the lowest chi-square (corresponding to the
182      * best solution).
183      */
184     @Ignore
185     @Test
186     public void testParametersErrorMonteCarloParameters() {
187         // Error on the observations.
188         final double yError = 15;
189 
190         // True values of the parameters.
191         final double slope = 123.456;
192         final double offset = -98.765;
193 
194         // Samples generator.
195         final RandomStraightLinePointGenerator lineGenerator
196             = new RandomStraightLinePointGenerator(slope, offset,
197                                                    yError,
198                                                    -1e3, 1e4,
199                                                    13839013L);
200 
201         // Number of observations.
202         final int numObs = 10;
203         // number of parameters.
204 
205         // Create a single set of observations.
206         final Point2D.Double[] obs = lineGenerator.generate(numObs);
207 
208         final StraightLineProblem problem = new StraightLineProblem(yError);
209         for (int i = 0; i < numObs; i++) {
210             final Point2D.Double p = obs[i];
211             problem.addPoint(p.x, p.y);
212         }
213 
214         // Direct solution (using simple regression).
215         final RealVector regress = new ArrayRealVector(problem.solve(), false);
216 
217         // Dummy optimizer (to compute the chi-square).
218         final LeastSquaresProblem lsp = builder(problem).build();
219 
220         // Get chi-square of the best parameters set for the given set of
221         // observations.
222         final double bestChi2N = getChi2N(lsp, regress);
223         final RealVector sigma = lsp.evaluate(regress).getSigma(1e-14);
224 
225         // Monte-Carlo (generates a grid of parameters).
226         final int mcRepeat = MONTE_CARLO_RUNS;
227         final int gridSize = (int) FastMath.sqrt(mcRepeat);
228 
229         // Parameters found for each of Monte-Carlo run.
230         // Index 0 = slope
231         // Index 1 = offset
232         // Index 2 = normalized chi2
233         final List<double[]> paramsAndChi2 = new ArrayList<double[]>(gridSize * gridSize);
234 
235         final double slopeRange = 10 * sigma.getEntry(0);
236         final double offsetRange = 10 * sigma.getEntry(1);
237         final double minSlope = slope - 0.5 * slopeRange;
238         final double minOffset = offset - 0.5 * offsetRange;
239         final double deltaSlope =  slopeRange/ gridSize;
240         final double deltaOffset = offsetRange / gridSize;
241         for (int i = 0; i < gridSize; i++) {
242             final double s = minSlope + i * deltaSlope;
243             for (int j = 0; j < gridSize; j++) {
244                 final double o = minOffset + j * deltaOffset;
245                 final double chi2N = getChi2N(lsp,
246                         new ArrayRealVector(new double[] {s, o}, false));
247 
248                 paramsAndChi2.add(new double[] {s, o, chi2N});
249             }
250         }
251 
252         // Output (for use with "gnuplot").
253 
254         // Some info.
255 
256         // For plotting separately sets of parameters that have a large chi2.
257         final double chi2NPlusOne = bestChi2N + 1;
258         int numLarger = 0;
259 
260         final String lineFmt = "%+.10e %+.10e   %.8e\n";
261 
262         // Point with smallest chi-square.
263         System.out.printf(lineFmt, regress.getEntry(0), regress.getEntry(1), bestChi2N);
264         System.out.println(); // Empty line.
265 
266         // Points within the confidence interval.
267         for (double[] d : paramsAndChi2) {
268             if (d[2] <= chi2NPlusOne) {
269                 System.out.printf(lineFmt, d[0], d[1], d[2]);
270             }
271         }
272         System.out.println(); // Empty line.
273 
274         // Points outside the confidence interval.
275         for (double[] d : paramsAndChi2) {
276             if (d[2] > chi2NPlusOne) {
277                 ++numLarger;
278                 System.out.printf(lineFmt, d[0], d[1], d[2]);
279             }
280         }
281         System.out.println(); // Empty line.
282 
283         System.out.println("# sigma=" + sigma.toString());
284         System.out.println("# " + numLarger + " sets filtered out");
285     }
286 
287     LeastSquaresBuilder builder(StraightLineProblem problem){
288         return new LeastSquaresBuilder()
289                 .model(problem.getModelFunction(), problem.getModelFunctionJacobian())
290                 .target(problem.target())
291                 .weight(new DiagonalMatrix(problem.weight()))
292                 //unused start point to avoid NPE
293                 .start(new double[2]);
294     }
295     /**
296      * @return the normalized chi-square.
297      */
298     private double getChi2N(LeastSquaresProblem lsp,
299                             RealVector params) {
300         final double cost = lsp.evaluate(params).getCost();
301         return cost * cost / (lsp.getObservationSize() - params.getDimension());
302     }
303 }
304