View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  package org.hipparchus.analysis.solvers;
23  
24  import java.lang.reflect.Field;
25  
26  import org.hipparchus.analysis.QuinticFunction;
27  import org.hipparchus.analysis.UnivariateFunction;
28  import org.hipparchus.analysis.XMinus5Function;
29  import org.hipparchus.analysis.function.Sin;
30  import org.hipparchus.exception.MathIllegalArgumentException;
31  import org.hipparchus.util.FastMath;
32  import org.junit.Assert;
33  import org.junit.Test;
34  
35  /**
36   * Base class for root-finding algorithms tests derived from
37   * {@link BaseSecantSolver}.
38   *
39   */
40  public abstract class BaseSecantSolverAbstractTest {
41      /** Returns the solver to use to perform the tests.
42       * @return the solver to use to perform the tests
43       */
44      protected abstract UnivariateSolver getSolver();
45  
46      /** Returns the expected number of evaluations for the
47       * {@link #testQuinticZero} unit test. A value of {@code -1} indicates that
48       * the test should be skipped for that solver.
49       * @return the expected number of evaluations for the
50       * {@link #testQuinticZero} unit test
51       */
52      protected abstract int[] getQuinticEvalCounts();
53  
54      @Test
55      public void testSinZero() {
56          // The sinus function is behaved well around the root at pi. The second
57          // order derivative is zero, which means linear approximating methods
58          // still converge quadratically.
59          UnivariateFunction f = new Sin();
60          double result;
61          UnivariateSolver solver = getSolver();
62  
63          result = solver.solve(100, f, 3, 4);
64          //System.out.println(
65          //    "Root: " + result + " Evaluations: " + solver.getEvaluations());
66          Assert.assertEquals(result, FastMath.PI, solver.getAbsoluteAccuracy());
67          Assert.assertTrue(solver.getEvaluations() <= 6);
68          result = solver.solve(100, f, 1, 4);
69          //System.out.println(
70          //    "Root: " + result + " Evaluations: " + solver.getEvaluations());
71          Assert.assertEquals(result, FastMath.PI, solver.getAbsoluteAccuracy());
72          Assert.assertTrue(solver.getEvaluations() <= 7);
73      }
74  
75      @Test
76      public void testQuinticZero() {
77          // The quintic function has zeros at 0, +-0.5 and +-1.
78          // Around the root of 0 the function is well behaved, with a second
79          // derivative of zero a 0.
80          // The other roots are less well to find, in particular the root at 1,
81          // because the function grows fast for x>1.
82          // The function has extrema (first derivative is zero) at 0.27195613
83          // and 0.82221643, intervals containing these values are harder for
84          // the solvers.
85          UnivariateFunction f = new QuinticFunction();
86          double result;
87          UnivariateSolver solver = getSolver();
88          double atol = solver.getAbsoluteAccuracy();
89          int[] counts = getQuinticEvalCounts();
90  
91          // Tests data: initial bounds, and expected solution, per test case.
92          double[][] testsData = {{-0.2,  0.2,  0.0},
93                                  {-0.1,  0.3,  0.0},
94                                  {-0.3,  0.45, 0.0},
95                                  { 0.3,  0.7,  0.5},
96                                  { 0.2,  0.6,  0.5},
97                                  { 0.05, 0.95, 0.5},
98                                  { 0.85, 1.25, 1.0},
99                                  { 0.8,  1.2,  1.0},
100                                 { 0.85, 1.75, 1.0},
101                                 { 0.55, 1.45, 1.0},
102                                 { 0.85, 5.0,  1.0},
103                                };
104         int maxIter = 500;
105 
106         for(int i = 0; i < testsData.length; i++) {
107             // Skip test, if needed.
108             if (counts[i] == -1) continue;
109 
110             // Compute solution.
111             double[] testData = testsData[i];
112             result = solver.solve(maxIter, f, testData[0], testData[1]);
113             //System.out.println(
114             //    "Root: " + result + " Evaluations: " + solver.getEvaluations());
115 
116             // Check solution.
117             Assert.assertEquals(result, testData[2], atol);
118             Assert.assertTrue("" + solver.getEvaluations() + " <= " + (counts[i] + 1),
119                     solver.getEvaluations() <= counts[i] + 1);
120         }
121     }
122 
123     @Test
124     public void testRootEndpoints() {
125         UnivariateFunction f = new XMinus5Function();
126         UnivariateSolver solver = getSolver();
127 
128         // End-point is root. This should be a special case in the solver, and
129         // the initial end-point should be returned exactly.
130         double result = solver.solve(100, f, 5.0, 6.0);
131         Assert.assertEquals(5.0, result, 0.0);
132 
133         result = solver.solve(100, f, 4.0, 5.0);
134         Assert.assertEquals(5.0, result, 0.0);
135 
136         result = solver.solve(100, f, 5.0, 6.0, 5.5);
137         Assert.assertEquals(5.0, result, 0.0);
138 
139         result = solver.solve(100, f, 4.0, 5.0, 4.5);
140         Assert.assertEquals(5.0, result, 0.0);
141     }
142 
143     @Test
144     public void testCloseEndpoints() {
145         UnivariateFunction f = new XMinus5Function();
146         UnivariateSolver solver = getSolver();
147 
148         double result = solver.solve(100, f, 5.0, FastMath.nextUp(5.0));
149         Assert.assertEquals(5.0, result, 0.0);
150 
151         result = solver.solve(100, f, FastMath.nextDown(5.0), 5.0);
152         Assert.assertEquals(5.0, result, 0.0);
153     }
154 
155     @Test
156     public void testBadEndpoints() {
157         UnivariateFunction f = new Sin();
158         UnivariateSolver solver = getSolver();
159         try {  // bad interval
160             solver.solve(100, f, 1, -1);
161             Assert.fail("Expecting MathIllegalArgumentException - bad interval");
162         } catch (MathIllegalArgumentException ex) {
163             // expected
164         }
165         try {  // no bracket
166             solver.solve(100, f, 1, 1.5);
167             Assert.fail("Expecting MathIllegalArgumentException - non-bracketing");
168         } catch (MathIllegalArgumentException ex) {
169             // expected
170         }
171         try {  // no bracket
172             solver.solve(100, f, 1, 1.5, 1.2);
173             Assert.fail("Expecting MathIllegalArgumentException - non-bracketing");
174         } catch (MathIllegalArgumentException ex) {
175             // expected
176         }
177     }
178 
179     @Test
180     public void testSolutionLeftSide() {
181         UnivariateFunction f = new Sin();
182         UnivariateSolver solver = getSolver();
183         double left = -1.5;
184         double right = 0.05;
185         for(int i = 0; i < 10; i++) {
186             // Test whether the allowed solutions are taken into account.
187             double solution = getSolution(solver, 100, f, left, right, AllowedSolution.LEFT_SIDE);
188             if (!Double.isNaN(solution)) {
189                 Assert.assertTrue(solution <= 0.0);
190             }
191 
192             // Prepare for next test.
193             left -= 0.1;
194             right += 0.3;
195         }
196     }
197 
198     @Test
199     public void testSolutionRightSide() {
200         UnivariateFunction f = new Sin();
201         UnivariateSolver solver = getSolver();
202         double left = -1.5;
203         double right = 0.05;
204         for(int i = 0; i < 10; i++) {
205             // Test whether the allowed solutions are taken into account.
206             double solution = getSolution(solver, 100, f, left, right, AllowedSolution.RIGHT_SIDE);
207             if (!Double.isNaN(solution)) {
208                 Assert.assertTrue(solution >= 0.0);
209             }
210 
211             // Prepare for next test.
212             left -= 0.1;
213             right += 0.3;
214         }
215     }
216     @Test
217     public void testSolutionBelowSide() {
218         UnivariateFunction f = new Sin();
219         UnivariateSolver solver = getSolver();
220         double left = -1.5;
221         double right = 0.05;
222         for(int i = 0; i < 10; i++) {
223             // Test whether the allowed solutions are taken into account.
224             double solution = getSolution(solver, 100, f, left, right, AllowedSolution.BELOW_SIDE);
225             if (!Double.isNaN(solution)) {
226                 Assert.assertTrue(f.value(solution) <= 0.0);
227             }
228 
229             // Prepare for next test.
230             left -= 0.1;
231             right += 0.3;
232         }
233     }
234 
235     @Test
236     public void testSolutionAboveSide() {
237         UnivariateFunction f = new Sin();
238         UnivariateSolver solver = getSolver();
239         double left = -1.5;
240         double right = 0.05;
241         for(int i = 0; i < 10; i++) {
242             // Test whether the allowed solutions are taken into account.
243             double solution = getSolution(solver, 100, f, left, right, AllowedSolution.ABOVE_SIDE);
244             if (!Double.isNaN(solution)) {
245                 Assert.assertTrue(f.value(solution) >= 0.0);
246             }
247 
248             // Prepare for next test.
249             left -= 0.1;
250             right += 0.3;
251         }
252     }
253 
254     private double getSolution(UnivariateSolver solver, int maxEval, UnivariateFunction f,
255                                double left, double right, AllowedSolution allowedSolution) {
256         try {
257             @SuppressWarnings("unchecked")
258             BracketedUnivariateSolver<UnivariateFunction> bracketing =
259             (BracketedUnivariateSolver<UnivariateFunction>) solver;
260             return bracketing.solve(100, f, left, right, allowedSolution);
261         } catch (ClassCastException cce) {
262             double baseRoot = solver.solve(maxEval, f, left, right);
263             if ((baseRoot <= left) || (baseRoot >= right)) {
264                 // the solution slipped out of interval
265                 return Double.NaN;
266             }
267             PegasusSolver bracketing =
268                     new PegasusSolver(solver.getRelativeAccuracy(), solver.getAbsoluteAccuracy(),
269                                       solver.getFunctionValueAccuracy());
270             return UnivariateSolverUtils.forceSide(maxEval - solver.getEvaluations(),
271                                                        f, bracketing, baseRoot, left, right,
272                                                        allowedSolution);
273         }
274     }
275 
276     protected void checktype(UnivariateSolver solver, BaseSecantSolver.Method expected) {
277         try {
278             Field methodField = BaseSecantSolver.class.getDeclaredField("method");
279             methodField.setAccessible(true);
280             BaseSecantSolver.Method method = (BaseSecantSolver.Method) methodField.get(solver);
281             Assert.assertEquals(expected, method);
282         } catch (IllegalAccessException | NoSuchFieldException | SecurityException e) {
283             Assert.fail(e.getLocalizedMessage());
284         }
285     }
286 
287 }