View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  
23  package org.hipparchus.optim.nonlinear.vector.leastsquares;
24  
25  import static org.hamcrest.CoreMatchers.is;
26  
27  import java.util.ArrayList;
28  import java.util.List;
29  
30  import org.hamcrest.MatcherAssert;
31  import org.hipparchus.analysis.MultivariateMatrixFunction;
32  import org.hipparchus.analysis.MultivariateVectorFunction;
33  import org.hipparchus.exception.LocalizedCoreFormats;
34  import org.hipparchus.exception.MathIllegalArgumentException;
35  import org.hipparchus.exception.MathIllegalStateException;
36  import org.hipparchus.geometry.euclidean.twod.Vector2D;
37  import org.hipparchus.linear.DiagonalMatrix;
38  import org.hipparchus.linear.RealMatrix;
39  import org.hipparchus.linear.RealVector;
40  import org.hipparchus.optim.ConvergenceChecker;
41  import org.hipparchus.optim.nonlinear.vector.leastsquares.LeastSquaresOptimizer.Optimum;
42  import org.hipparchus.optim.nonlinear.vector.leastsquares.LeastSquaresProblem.Evaluation;
43  import org.hipparchus.util.FastMath;
44  import org.hipparchus.util.Incrementor;
45  import org.hipparchus.util.Precision;
46  import org.junit.Assert;
47  import org.junit.Test;
48  
49  /**
50   * <p>Some of the unit tests are re-implementations of the MINPACK <a
51   * href="http://www.netlib.org/minpack/ex/file17">file17</a> and <a
52   * href="http://www.netlib.org/minpack/ex/file22">file22</a> test files.
53   * The redistribution policy for MINPACK is available <a
54   * href="http://www.netlib.org/minpack/disclaimer">here</a>.
55   *
56   */
57  public class LevenbergMarquardtOptimizerTest
58      extends AbstractLeastSquaresOptimizerAbstractTest{
59  
60      public LeastSquaresBuilder builder(BevingtonProblem problem){
61          return base()
62                  .model(problem.getModelFunction(), problem.getModelFunctionJacobian());
63      }
64  
65      public LeastSquaresBuilder builder(CircleProblem problem){
66          return base()
67                  .model(problem.getModelFunction(), problem.getModelFunctionJacobian())
68                  .target(problem.target())
69                  .weight(new DiagonalMatrix(problem.weight()));
70      }
71  
72      @Override
73      public int getMaxIterations() {
74          return 25;
75      }
76  
77      @Override
78      public LeastSquaresOptimizer getOptimizer() {
79          return new LevenbergMarquardtOptimizer();
80      }
81  
82      @Override
83      @Test
84      public void testNonInvertible() {
85          /*
86           * Overrides the method from parent class, since the default singularity
87           * threshold (1e-14) does not trigger the expected exception.
88           */
89          LinearProblem problem = new LinearProblem(new double[][] {
90              {  1, 2, -3 },
91              {  2, 1,  3 },
92              { -3, 0, -9 }
93          }, new double[] { 1, 1, 1 });
94  
95          final Optimum optimum = optimizer.optimize(
96                                                     problem.getBuilder().maxIterations(20).build());
97  
98          //TODO check that it is a bad fit? Why the extra conditions?
99          Assert.assertTrue(FastMath.sqrt(problem.getTarget().length) * optimum.getRMS() > 0.6);
100 
101         try {
102             optimum.getCovariances(1.5e-14);
103             fail(optimizer);
104         } catch (MathIllegalArgumentException e) {
105             Assert.assertEquals(LocalizedCoreFormats.SINGULAR_MATRIX, e.getSpecifier());
106         }
107 
108     }
109 
110     @Test
111     public void testControlParameters() {
112         CircleVectorial circle = new CircleVectorial();
113         circle.addPoint( 30.0,  68.0);
114         circle.addPoint( 50.0,  -6.0);
115         circle.addPoint(110.0, -20.0);
116         circle.addPoint( 35.0,  15.0);
117         circle.addPoint( 45.0,  97.0);
118         checkEstimate(
119                 circle, 0.1, 10, 1.0e-14, 1.0e-16, 1.0e-10, false);
120         checkEstimate(
121                 circle, 0.1, 10, 1.0e-15, 1.0e-17, 1.0e-10, true);
122         checkEstimate(
123                 circle, 0.1,  5, 1.0e-15, 1.0e-16, 1.0e-10, true);
124         circle.addPoint(300, -300);
125         //wardev I changed true => false
126         //TODO why should this fail? It uses 15 evaluations.
127         checkEstimate(
128                 circle, 0.1, 20, 1.0e-18, 1.0e-16, 1.0e-10, false);
129     }
130 
131     private void checkEstimate(CircleVectorial circle,
132                                double initialStepBoundFactor, int maxCostEval,
133                                double costRelativeTolerance, double parRelativeTolerance,
134                                double orthoTolerance, boolean shouldFail) {
135         try {
136             final LevenbergMarquardtOptimizer optimizer = new LevenbergMarquardtOptimizer()
137                 .withInitialStepBoundFactor(initialStepBoundFactor)
138                 .withCostRelativeTolerance(costRelativeTolerance)
139                 .withParameterRelativeTolerance(parRelativeTolerance)
140                 .withOrthoTolerance(orthoTolerance)
141                 .withRankingThreshold(Precision.SAFE_MIN);
142 
143             final LeastSquaresProblem problem = builder(circle)
144                     .maxEvaluations(maxCostEval)
145                     .maxIterations(100)
146                     .start(new double[] { 98.680, 47.345 })
147                     .build();
148 
149             optimizer.optimize(problem);
150 
151             Assert.assertTrue(!shouldFail);
152             //TODO check it got the right answer
153 
154         } catch (MathIllegalArgumentException ee) {
155             Assert.assertTrue(shouldFail);
156         } catch (MathIllegalStateException ee) {
157             Assert.assertTrue(shouldFail);
158         }
159     }
160 
161     /**
162      * Non-linear test case: fitting of decay curve (from Chapter 8 of
163      * Bevington's textbook, "Data reduction and analysis for the physical sciences").
164      * XXX The expected ("reference") values may not be accurate and the tolerance too
165      * relaxed for this test to be currently really useful (the issue is under
166      * investigation).
167      */
168     @Test
169     public void testBevington() {
170         final double[][] dataPoints = {
171             // column 1 = times
172             { 15, 30, 45, 60, 75, 90, 105, 120, 135, 150,
173               165, 180, 195, 210, 225, 240, 255, 270, 285, 300,
174               315, 330, 345, 360, 375, 390, 405, 420, 435, 450,
175               465, 480, 495, 510, 525, 540, 555, 570, 585, 600,
176               615, 630, 645, 660, 675, 690, 705, 720, 735, 750,
177               765, 780, 795, 810, 825, 840, 855, 870, 885, },
178             // column 2 = measured counts
179             { 775, 479, 380, 302, 185, 157, 137, 119, 110, 89,
180               74, 61, 66, 68, 48, 54, 51, 46, 55, 29,
181               28, 37, 49, 26, 35, 29, 31, 24, 25, 35,
182               24, 30, 26, 28, 21, 18, 20, 27, 17, 17,
183               14, 17, 24, 11, 22, 17, 12, 10, 13, 16,
184               9, 9, 14, 21, 17, 13, 12, 18, 10, },
185         };
186         final double[] start = {10, 900, 80, 27, 225};
187 
188         final BevingtonProblem problem = new BevingtonProblem();
189 
190         final int len = dataPoints[0].length;
191         final double[] weights = new double[len];
192         for (int i = 0; i < len; i++) {
193             problem.addPoint(dataPoints[0][i],
194                              dataPoints[1][i]);
195 
196             weights[i] = 1 / dataPoints[1][i];
197         }
198 
199         final Optimum optimum = optimizer.optimize(
200                 builder(problem)
201                         .target(dataPoints[1])
202                         .weight(new DiagonalMatrix(weights))
203                         .start(start)
204                         .maxIterations(20)
205                         .build()
206         );
207 
208         final RealVector solution = optimum.getPoint();
209         final double[] expectedSolution = { 10.4, 958.3, 131.4, 33.9, 205.0 };
210 
211         final RealMatrix covarMatrix = optimum.getCovariances(1e-14);
212         final double[][] expectedCovarMatrix = {
213             { 3.38, -3.69, 27.98, -2.34, -49.24 },
214             { -3.69, 2492.26, 81.89, -69.21, -8.9 },
215             { 27.98, 81.89, 468.99, -44.22, -615.44 },
216             { -2.34, -69.21, -44.22, 6.39, 53.80 },
217             { -49.24, -8.9, -615.44, 53.8, 929.45 }
218         };
219 
220         final int numParams = expectedSolution.length;
221 
222         // Check that the computed solution is within the reference error range.
223         for (int i = 0; i < numParams; i++) {
224             final double error = FastMath.sqrt(expectedCovarMatrix[i][i]);
225             Assert.assertEquals("Parameter " + i, expectedSolution[i], solution.getEntry(i), error);
226         }
227 
228         // Check that each entry of the computed covariance matrix is within 10%
229         // of the reference matrix entry.
230         for (int i = 0; i < numParams; i++) {
231             for (int j = 0; j < numParams; j++) {
232                 Assert.assertEquals("Covariance matrix [" + i + "][" + j + "]",
233                                     expectedCovarMatrix[i][j],
234                                     covarMatrix.getEntry(i, j),
235                                     FastMath.abs(0.1 * expectedCovarMatrix[i][j]));
236             }
237         }
238 
239         // Check various measures of goodness-of-fit.
240         final double chi2 = optimum.getChiSquare();
241         final double cost = optimum.getCost();
242         final double rms = optimum.getRMS();
243         final double reducedChi2 = optimum.getReducedChiSquare(start.length);
244 
245         // XXX Values computed by the CM code: It would be better to compare
246         // with the results from another library.
247         final double expectedChi2 = 66.07852350839286;
248         final double expectedReducedChi2 = 1.2014277001525975;
249         final double expectedCost = 8.128869755900439;
250         final double expectedRms = 1.0582887010256337;
251 
252         final double tol = 1e14;
253         Assert.assertEquals(expectedChi2, chi2, tol);
254         Assert.assertEquals(expectedReducedChi2, reducedChi2, tol);
255         Assert.assertEquals(expectedCost, cost, tol);
256         Assert.assertEquals(expectedRms, rms, tol);
257     }
258 
259     @Test
260     public void testCircleFitting2() {
261         final double xCenter = 123.456;
262         final double yCenter = 654.321;
263         final double xSigma = 10;
264         final double ySigma = 15;
265         final double radius = 111.111;
266         // The test is extremely sensitive to the seed.
267         final long seed = 59421061L;
268         final RandomCirclePointGenerator factory
269             = new RandomCirclePointGenerator(xCenter, yCenter, radius,
270                                              xSigma, ySigma,
271                                              seed);
272         final CircleProblem circle = new CircleProblem(xSigma, ySigma);
273 
274         final int numPoints = 10;
275         for (Vector2D p : factory.generate(numPoints)) {
276             circle.addPoint(p.getX(), p.getY());
277         }
278 
279         // First guess for the center's coordinates and radius.
280         final double[] init = { 90, 659, 115 };
281 
282         Incrementor incrementor = new Incrementor();
283         final Optimum optimum = optimizer.optimize(
284                 LeastSquaresFactory.countEvaluations(builder(circle).maxIterations(50).start(init).build(),
285                                                      incrementor));
286 
287         final double[] paramFound = optimum.getPoint().toArray();
288 
289         // Retrieve errors estimation.
290         final double[] asymptoticStandardErrorFound = optimum.getSigma(1e-14).toArray();
291 
292         // Check that the parameters are found within the assumed error bars.
293         Assert.assertEquals(xCenter, paramFound[0], 3 * asymptoticStandardErrorFound[0]);
294         Assert.assertEquals(yCenter, paramFound[1], 3 * asymptoticStandardErrorFound[1]);
295         Assert.assertEquals(radius,  paramFound[2], 3 * asymptoticStandardErrorFound[2]);
296         Assert.assertTrue(incrementor.getCount() < 40);
297     }
298 
299     @Test
300     public void testParameterValidator() {
301         // Setup.
302         final double xCenter = 123.456;
303         final double yCenter = 654.321;
304         final double xSigma = 10;
305         final double ySigma = 15;
306         final double radius = 111.111;
307         final long seed = 3456789L;
308         final RandomCirclePointGenerator factory
309             = new RandomCirclePointGenerator(xCenter, yCenter, radius,
310                                              xSigma, ySigma,
311                                              seed);
312         final CircleProblem circle = new CircleProblem(xSigma, ySigma);
313 
314         final int numPoints = 10;
315         for (Vector2D p : factory.generate(numPoints)) {
316             circle.addPoint(p.getX(), p.getY());
317         }
318 
319         // First guess for the center's coordinates and radius.
320         final double[] init = { 90, 659, 115 };
321         final Optimum optimum
322             = optimizer.optimize(builder(circle).maxIterations(50).start(init).build());
323         final int numEval = optimum.getEvaluations();
324         Assert.assertTrue(numEval > 1);
325 
326         // Build a new problem with a validator that amounts to cheating.
327         final ParameterValidator cheatValidator
328             = new ParameterValidator() {
329                     public RealVector validate(RealVector params) {
330                         // Cheat: return the optimum found previously.
331                         return optimum.getPoint();
332                     }
333                 };
334 
335         final Optimum cheatOptimum
336             = optimizer.optimize(builder(circle).maxIterations(50).start(init).parameterValidator(cheatValidator).build());
337         final int cheatNumEval = cheatOptimum.getEvaluations();
338         Assert.assertTrue(cheatNumEval < numEval);
339         // System.out.println("n=" + numEval + " nc=" + cheatNumEval);
340     }
341 
342     @Test
343     public void testEvaluationCount() {
344         //setup
345         LeastSquaresProblem lsp = new LinearProblem(new double[][] {{1}}, new double[] {1})
346                 .getBuilder()
347                 .checker(new ConvergenceChecker<Evaluation>() {
348                     public boolean converged(int iteration, Evaluation previous, Evaluation current) {
349                         return true;
350                     }
351                 })
352                 .build();
353 
354         //action
355         Optimum optimum = optimizer.optimize(lsp);
356 
357         //verify
358         //check iterations and evaluations are not switched.
359         MatcherAssert.assertThat(optimum.getIterations(), is(1));
360         MatcherAssert.assertThat(optimum.getEvaluations(), is(2));
361     }
362 
363     private static class BevingtonProblem {
364         private List<Double> time;
365         private List<Double> count;
366 
367         public BevingtonProblem() {
368             time = new ArrayList<Double>();
369             count = new ArrayList<Double>();
370         }
371 
372         public void addPoint(double t, double c) {
373             time.add(t);
374             count.add(c);
375         }
376 
377         public MultivariateVectorFunction getModelFunction() {
378             return new MultivariateVectorFunction() {
379                 public double[] value(double[] params) {
380                     double[] values = new double[time.size()];
381                     for (int i = 0; i < values.length; ++i) {
382                         final double t = time.get(i);
383                         values[i] = params[0] +
384                             params[1] * FastMath.exp(-t / params[3]) +
385                             params[2] * FastMath.exp(-t / params[4]);
386                     }
387                     return values;
388                 }
389             };
390         }
391 
392         public MultivariateMatrixFunction getModelFunctionJacobian() {
393             return new MultivariateMatrixFunction() {
394                 public double[][] value(double[] params) {
395                     double[][] jacobian = new double[time.size()][5];
396 
397                     for (int i = 0; i < jacobian.length; ++i) {
398                         final double t = time.get(i);
399                         jacobian[i][0] = 1;
400 
401                         final double p3 =  params[3];
402                         final double p4 =  params[4];
403                         final double tOp3 = t / p3;
404                         final double tOp4 = t / p4;
405                         jacobian[i][1] = FastMath.exp(-tOp3);
406                         jacobian[i][2] = FastMath.exp(-tOp4);
407                         jacobian[i][3] = params[1] * FastMath.exp(-tOp3) * tOp3 / p3;
408                         jacobian[i][4] = params[2] * FastMath.exp(-tOp4) * tOp4 / p4;
409                     }
410                     return jacobian;
411                 }
412             };
413         }
414     }
415 }