/*
 * Licensed to the Hipparchus project under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The Hipparchus project licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17  package org.hipparchus.analysis.differentiation;
18  
19  import org.hipparchus.Field;
20  import org.hipparchus.CalculusFieldElement;
21  import org.hipparchus.CalculusFieldElementAbstractTest;
22  import org.hipparchus.UnitTestUtils;
23  import org.hipparchus.analysis.FieldUnivariateFunction;
24  import org.hipparchus.exception.LocalizedCoreFormats;
25  import org.hipparchus.exception.MathIllegalArgumentException;
26  import org.hipparchus.util.FastMath;
27  import org.hipparchus.util.FieldSinCos;
28  import org.hipparchus.util.MathArrays;
29  import org.junit.Assert;
30  import org.junit.Test;
31  
32  /**
33   * Test for class {@link UnivariateDerivative}.
34   */
35  public class GradientTest extends CalculusFieldElementAbstractTest<Gradient> {
36  
37      @Override
38      protected Gradient build(final double x) {
39          // the function is really a two variables function : f(x) = g(x, 0) with g(x, y) = x + y / 1024
40          return new Gradient(x, 1.0, FastMath.scalb(1.0, -10));
41      }
42  
43      @Test
44      public void testGetGradient() {
45          Gradient g = new Gradient(-0.5, 2.5, 10.0, -1.0);
46          Assert.assertEquals(-0.5, g.getReal(), 1.0e-15);
47          Assert.assertEquals(-0.5, g.getValue(), 1.0e-15);
48          Assert.assertEquals(+2.5, g.getGradient()[0], 1.0e-15);
49          Assert.assertEquals(10.0, g.getGradient()[1], 1.0e-15);
50          Assert.assertEquals(-1.0, g.getGradient()[2], 1.0e-15);
51          Assert.assertEquals(+2.5, g.getPartialDerivative(0), 1.0e-15);
52          Assert.assertEquals(10.0, g.getPartialDerivative(1), 1.0e-15);
53          Assert.assertEquals(-1.0, g.getPartialDerivative(2), 1.0e-15);
54          Assert.assertEquals(3, g.getFreeParameters());
55          try {
56              g.getPartialDerivative(-1);
57              Assert.fail("an exception should have been thrown");
58          } catch (MathIllegalArgumentException miae) {
59              Assert.assertEquals(LocalizedCoreFormats.OUT_OF_RANGE_SIMPLE, miae.getSpecifier());
60          }
61          try {
62              g.getPartialDerivative(+3);
63              Assert.fail("an exception should have been thrown");
64          } catch (MathIllegalArgumentException miae) {
65              Assert.assertEquals(LocalizedCoreFormats.OUT_OF_RANGE_SIMPLE, miae.getSpecifier());
66          }
67      }
68  
69      @Test
70      public void testConstant() {
71          Gradient g = Gradient.constant(5, -4.5);
72          Assert.assertEquals(5, g.getFreeParameters());
73          Assert.assertEquals(-4.5, g.getValue(), 1.0e-15);
74          for (int i = 0 ; i < g.getFreeParameters(); ++i) {
75              Assert.assertEquals(0.0, g.getPartialDerivative(i), 1.0e-15);
76          }
77      }
78  
79      @Test
80      public void testVariable() {
81          Gradient g = Gradient.variable(5, 1, -4.5);
82          Assert.assertEquals(5, g.getFreeParameters());
83          Assert.assertEquals(-4.5, g.getValue(), 1.0e-15);
84          for (int i = 0 ; i < g.getFreeParameters(); ++i) {
85              Assert.assertEquals(i == 1 ? 1.0 : 0.0, g.getPartialDerivative(i), 1.0e-15);
86          }
87      }
88  
89      @Test
90      public void testDoublePow() {
91          Assert.assertSame(build(3).getField().getZero(), Gradient.pow(0.0, build(1.5)));
92          Gradient g = Gradient.pow(2.0, build(1.5));
93          DSFactory factory = new DSFactory(2, 1);
94          DerivativeStructure ds = factory.constant(2.0).pow(factory.build(1.5, 1.0, FastMath.scalb(1.0, -10)));
95          Assert.assertEquals(ds.getValue(), g.getValue(), 1.0e-15);
96          final int[] indices = new int[ds.getFreeParameters()];
97          for (int i = 0; i < g.getFreeParameters(); ++i) {
98              indices[i] = 1;
99              Assert.assertEquals(ds.getPartialDerivative(indices), g.getPartialDerivative(i), 1.0e-15);
100             indices[i] = 0;
101         }
102     }
103 
104     @Test
105     public void testTaylor() {
106         Assert.assertEquals(2.75, new Gradient(2, 1, 0.125).taylor(0.5, 2.0), 1.0e-15);
107     }
108 
109     @Test
110     public void testOrder() {
111         Assert.assertEquals(1, new Gradient(2,  1, 0.125).getOrder());
112     }
113 
114     @Test
115     public void testGetPartialDerivative() {
116         final Gradient g = new Gradient(2,  1, 0.125);
117         Assert.assertEquals(2.0,   g.getPartialDerivative(0, 0), 1.0e-15); // f(x,y)
118         Assert.assertEquals(1.0,   g.getPartialDerivative(1, 0), 1.0e-15); // ∂f/∂x
119         Assert.assertEquals(0.125, g.getPartialDerivative(0, 1), 1.0e-15); // ∂f/∂y
120     }
121 
122     @Test
123     public void testGetPartialDerivativeErrors() {
124         final Gradient g = new Gradient(2,  1, 0.125);
125         try {
126             g.getPartialDerivative(0, 0, 0);
127             Assert.fail("an exception should have been thrown");
128         } catch (MathIllegalArgumentException miae) {
129             Assert.assertEquals(LocalizedCoreFormats.DIMENSIONS_MISMATCH, miae.getSpecifier());
130             Assert.assertEquals(3, ((Integer) miae.getParts()[0]).intValue());
131             Assert.assertEquals(2, ((Integer) miae.getParts()[1]).intValue());
132         }
133         try {
134             g.getPartialDerivative(0, 5);
135             Assert.fail("an exception should have been thrown");
136         } catch (MathIllegalArgumentException miae) {
137             Assert.assertEquals(LocalizedCoreFormats.DERIVATION_ORDER_NOT_ALLOWED, miae.getSpecifier());
138             Assert.assertEquals(5, ((Integer) miae.getParts()[0]).intValue());
139         }
140         try {
141             g.getPartialDerivative(1, 1);
142             Assert.fail("an exception should have been thrown");
143         } catch (MathIllegalArgumentException miae) {
144             Assert.assertEquals(LocalizedCoreFormats.DERIVATION_ORDER_NOT_ALLOWED, miae.getSpecifier());
145             Assert.assertEquals(1, ((Integer) miae.getParts()[0]).intValue());
146         }
147     }
148 
149     @Test
150     public void testHashcode() {
151         Assert.assertEquals(1608501298, new Gradient(2, 1, -0.25).hashCode());
152     }
153 
154     @Test
155     public void testEquals() {
156         Gradient g = new Gradient(12, -34, 56);
157         Assert.assertEquals(g, g);
158         Assert.assertNotEquals(g, "");
159         Assert.assertEquals(g, new Gradient(12, -34, 56));
160         Assert.assertNotEquals(g, new Gradient(21, -34, 56));
161         Assert.assertNotEquals(g, new Gradient(12, -43, 56));
162         Assert.assertNotEquals(g, new Gradient(12, -34, 65));
163         Assert.assertNotEquals(g, new Gradient(21, -43, 65));
164     }
165 
166     @Test
167     public void testRunTimeClass() {
168         Field<Gradient> field = build(0.0).getField();
169         Assert.assertEquals(Gradient.class, field.getRuntimeClass());
170     }
171 
172     @Test
173     public void testConversion() {
174         Gradient gA = new Gradient(-0.5, 2.5, 4.5);
175         DerivativeStructure ds = gA.toDerivativeStructure();
176         Assert.assertEquals(2, ds.getFreeParameters());
177         Assert.assertEquals(1, ds.getOrder());
178         Assert.assertEquals(-0.5, ds.getValue(), 1.0e-15);
179         Assert.assertEquals(-0.5, ds.getPartialDerivative(0, 0), 1.0e-15);
180         Assert.assertEquals( 2.5, ds.getPartialDerivative(1, 0), 1.0e-15);
181         Assert.assertEquals( 4.5, ds.getPartialDerivative(0, 1), 1.0e-15);
182         Gradient gB = new Gradient(ds);
183         Assert.assertNotSame(gA, gB);
184         Assert.assertEquals(gA, gB);
185         try {
186             new Gradient(new DSFactory(1, 2).variable(0, 1.0));
187             Assert.fail("an exception should have been thrown");
188         } catch (MathIllegalArgumentException miae) {
189             Assert.assertEquals(LocalizedCoreFormats.DIMENSIONS_MISMATCH, miae.getSpecifier());
190         }
191     }
192 
193     @Test
194     public void testNewInstance() {
195         Gradient g = build(5.25);
196         Assert.assertEquals(5.25, g.getValue(), 1.0e-15);
197         Assert.assertEquals(1.0,  g.getPartialDerivative(0), 1.0e-15);
198         Assert.assertEquals(0.0009765625,  g.getPartialDerivative(1), 1.0e-15);
199         Gradient newInstance = g.newInstance(7.5);
200         Assert.assertEquals(7.5, newInstance.getValue(), 1.0e-15);
201         Assert.assertEquals(0.0, newInstance.getPartialDerivative(0), 1.0e-15);
202         Assert.assertEquals(0.0, newInstance.getPartialDerivative(1), 1.0e-15);
203     }
204 
205     protected void checkAgainstDS(final double x, final FieldUnivariateFunction f) {
206         final Gradient xG = build(x);
207         final Gradient yG = f.value(xG);
208         final DerivativeStructure yDS = f.value(xG.toDerivativeStructure());
209         Assert.assertEquals(yDS.getFreeParameters(), yG.getFreeParameters());
210         Assert.assertEquals(yDS.getValue(), yG.getValue(), 1.0e-15 * FastMath.abs(yDS.getValue()));
211         final int[] indices = new int[yDS.getFreeParameters()];
212         for (int i = 0; i < yG.getFreeParameters(); ++i) {
213             indices[i] = 1;
214             Assert.assertEquals(yDS.getPartialDerivative(indices),
215                                 yG.getPartialDerivative(i),
216                                 4.0e-14* FastMath.abs(yDS.getPartialDerivative(indices)));
217             indices[i] = 0;
218         }
219     }
220 
221     @Test
222     public void testArithmeticVsDS() {
223         for (double x = -1.25; x < 1.25; x+= 0.5) {
224             checkAgainstDS(x,
225                            new FieldUnivariateFunction() {
226                                public <S extends CalculusFieldElement<S>> S value(S x) {
227                                    final S y = x.add(3).multiply(x).subtract(5).multiply(0.5);
228                                    return y.negate().divide(4).divide(x).add(y).subtract(x).multiply(2).reciprocal();
229                                }
230                            });
231         }
232     }
233 
234     @Test
235     public void testRemainderDoubleVsDS() {
236         for (double x = -1.25; x < 1.25; x+= 0.5) {
237             checkAgainstDS(x,
238                            new FieldUnivariateFunction() {
239                                public <S extends CalculusFieldElement<S>> S value(S x) {
240                                    return x.remainder(0.5);
241                                }
242                            });
243         }
244     }
245 
246     @Test
247     public void testRemainderGVsDS() {
248         for (double x = -1.25; x < 1.25; x+= 0.5) {
249             checkAgainstDS(x,
250                            new FieldUnivariateFunction() {
251                               public <S extends CalculusFieldElement<S>> S value(S x) {
252                                   return x.remainder(x.divide(0.7));
253                               }
254                            });
255         }
256     }
257 
258     @Test
259     public void testAbsVsDS() {
260         for (double x = -1.25; x < 1.25; x+= 0.5) {
261             checkAgainstDS(x,
262                            new FieldUnivariateFunction() {
263                                public <S extends CalculusFieldElement<S>> S value(S x) {
264                                    return x.abs();
265                                }
266                            });
267         }
268     }
269 
270     @Test
271     public void testHypotVsDS() {
272         for (double x = -3.25; x < 3.25; x+= 0.5) {
273             checkAgainstDS(x,
274                            new FieldUnivariateFunction() {
275                                public <S extends CalculusFieldElement<S>> S value(S x) {
276                                    return x.cos().multiply(5).hypot(x.sin().multiply(2));
277                                }
278                            });
279         }
280     }
281 
282     @Test
283     public void testAtan2VsDS() {
284         for (double x = -3.25; x < 3.25; x+= 0.5) {
285             checkAgainstDS(x,
286                            new FieldUnivariateFunction() {
287                                public <S extends CalculusFieldElement<S>> S value(S x) {
288                                    return x.cos().multiply(5).atan2(x.sin().multiply(2));
289                                }
290                            });
291         }
292     }
293 
294     @Test
295     public void testPowersVsDS() {
296         for (double x = -3.25; x < 3.25; x+= 0.5) {
297             checkAgainstDS(x,
298                            new FieldUnivariateFunction() {
299                                public <S extends CalculusFieldElement<S>> S value(S x) {
300                                    final FieldSinCos<S> sc = x.sinCos();
301                                    return x.pow(3.2).add(x.pow(2)).subtract(sc.cos().abs().pow(sc.sin()));
302                                }
303                            });
304         }
305     }
306 
307     @Test
308     public void testRootsVsDS() {
309         for (double x = 0.001; x < 3.25; x+= 0.5) {
310             checkAgainstDS(x,
311                            new FieldUnivariateFunction() {
312                                public <S extends CalculusFieldElement<S>> S value(S x) {
313                                    return x.rootN(5);//x.sqrt().add(x.cbrt()).subtract(x.rootN(5));
314                                }
315                            });
316         }
317     }
318 
319     @Test
320     public void testExpsLogsVsDS() {
321         for (double x = 2.5; x < 3.25; x+= 0.125) {
322             checkAgainstDS(x,
323                            new FieldUnivariateFunction() {
324                                public <S extends CalculusFieldElement<S>> S value(S x) {
325                                    return x.exp().add(x.multiply(0.5).expm1()).log().log10().log1p();
326                                }
327                            });
328         }
329     }
330 
331     @Test
332     public void testTrigonometryVsDS() {
333         for (double x = -3.25; x < 3.25; x+= 0.5) {
334             checkAgainstDS(x,
335                            new FieldUnivariateFunction() {
336                                public <S extends CalculusFieldElement<S>> S value(S x) {
337                                    return x.cos().multiply(x.sin()).atan().divide(12).asin().multiply(0.1).acos().tan();
338                                }
339                            });
340         }
341     }
342 
343     @Test
344     public void testHyperbolicVsDS() {
345         for (double x = -1.25; x < 1.25; x+= 0.5) {
346             checkAgainstDS(x,
347                            new FieldUnivariateFunction() {
348                                public <S extends CalculusFieldElement<S>> S value(S x) {
349                                    return x.cosh().multiply(x.sinh()).multiply(12).abs().acosh().asinh().divide(7).tanh().multiply(0.1).atanh();
350                                }
351                            });
352         }
353     }
354 
355     @Test
356     public void testConvertersVsDS() {
357         for (double x = -1.25; x < 1.25; x+= 0.5) {
358             checkAgainstDS(x,
359                            new FieldUnivariateFunction() {
360                                public <S extends CalculusFieldElement<S>> S value(S x) {
361                                    return x.multiply(5).toDegrees().subtract(x).toRadians();
362                                }
363                            });
364         }
365     }
366 
367     @Test
368     public void testLinearCombination2D2FVsDS() {
369         for (double x = -1.25; x < 1.25; x+= 0.5) {
370             checkAgainstDS(x,
371                            new FieldUnivariateFunction() {
372                                public <S extends CalculusFieldElement<S>> S value(S x) {
373                                    return x.linearCombination(1.0, x.multiply(0.9),
374                                                               2.0, x.multiply(0.8));
375                                }
376                            });
377         }
378     }
379 
380     @Test
381     public void testLinearCombination2F2FVsDS() {
382         for (double x = -1.25; x < 1.25; x+= 0.5) {
383             checkAgainstDS(x,
384                            new FieldUnivariateFunction() {
385                                public <S extends CalculusFieldElement<S>> S value(S x) {
386                                    return x.linearCombination(x.add(1), x.multiply(0.9),
387                                                               x.add(2), x.multiply(0.8));
388                                }
389                            });
390         }
391     }
392 
393     @Test
394     public void testLinearCombination3D3FVsDS() {
395         for (double x = -1.25; x < 1.25; x+= 0.5) {
396             checkAgainstDS(x,
397                            new FieldUnivariateFunction() {
398                                public <S extends CalculusFieldElement<S>> S value(S x) {
399                                    return x.linearCombination(1.0, x.multiply(0.9),
400                                                               2.0, x.multiply(0.8),
401                                                               3.0, x.multiply(0.7));
402                                }
403                            });
404         }
405     }
406 
407     @Test
408     public void testLinearCombination3F3FVsDS() {
409         for (double x = -1.25; x < 1.25; x+= 0.5) {
410             checkAgainstDS(x,
411                            new FieldUnivariateFunction() {
412                                public <S extends CalculusFieldElement<S>> S value(S x) {
413                                    return x.linearCombination(x.add(1), x.multiply(0.9),
414                                                               x.add(2), x.multiply(0.8),
415                                                               x.add(3), x.multiply(0.7));
416                                }
417                            });
418         }
419     }
420 
421     @Test
422     public void testLinearCombination4D4FVsDS() {
423         for (double x = -1.25; x < 1.25; x+= 0.5) {
424             checkAgainstDS(x,
425                            new FieldUnivariateFunction() {
426                                public <S extends CalculusFieldElement<S>> S value(S x) {
427                                    return x.linearCombination(1.0, x.multiply(0.9),
428                                                               2.0, x.multiply(0.8),
429                                                               3.0, x.multiply(0.7),
430                                                               4.0, x.multiply(0.6));
431                                }
432                            });
433         }
434     }
435 
436     @Test
437     public void testLinearCombination4F4FVsDS() {
438         for (double x = -1.25; x < 1.25; x+= 0.5) {
439             checkAgainstDS(x,
440                            new FieldUnivariateFunction() {
441                                public <S extends CalculusFieldElement<S>> S value(S x) {
442                                    return x.linearCombination(x.add(1), x.multiply(0.9),
443                                                               x.add(2), x.multiply(0.8),
444                                                               x.add(3), x.multiply(0.7),
445                                                               x.add(4), x.multiply(0.6));
446                                }
447                            });
448         }
449     }
450 
451     @Test
452     public void testLinearCombinationnDnFVsDS() {
453         for (double x = -1.25; x < 1.25; x+= 0.5) {
454             checkAgainstDS(x,
455                            new FieldUnivariateFunction() {
456                                public <S extends CalculusFieldElement<S>> S value(S x) {
457                                    final S[] b = MathArrays.buildArray(x.getField(), 4);
458                                    b[0] = x.add(0.9);
459                                    b[1] = x.add(0.8);
460                                    b[2] = x.add(0.7);
461                                    b[3] = x.add(0.6);
462                                    return x.linearCombination(new double[] { 1, 2, 3, 4 }, b);
463                                }
464                            });
465         }
466     }
467 
468     @Test
469     public void testLinearCombinationnFnFVsDS() {
470         for (double x = -1.25; x < 1.25; x+= 0.5) {
471             checkAgainstDS(x,
472                            new FieldUnivariateFunction() {
473                                public <S extends CalculusFieldElement<S>> S value(S x) {
474                                    final S[] a = MathArrays.buildArray(x.getField(), 4);
475                                    a[0] = x.add(1);
476                                    a[1] = x.add(2);
477                                    a[2] = x.add(3);
478                                    a[3] = x.add(4);
479                                    final S[] b = MathArrays.buildArray(x.getField(), 4);
480                                    b[0] = x.add(0.9);
481                                    b[1] = x.add(0.8);
482                                    b[2] = x.add(0.7);
483                                    b[3] = x.add(0.6);
484                                    return x.linearCombination(a, b);
485                                }
486                            });
487         }
488     }
489 
490     @Test
491     public void testSerialization() {
492         Gradient a = build(1.3);
493         Gradient b = (Gradient) UnitTestUtils.serializeAndRecover(a);
494         Assert.assertEquals(a, b);
495         Assert.assertNotSame(a, b);
496     }
497 
498     @Test
499     public void testZero() {
500         Gradient zero = build(17.0).getField().getZero();
501         Assert.assertEquals(0.0, zero.getValue(), 1.0e-15);
502         for (int i = 0; i < zero.getFreeParameters(); ++i) {
503             Assert.assertEquals(0.0, zero.getPartialDerivative(i), 1.0e-15);
504         }
505     }
506 
507     @Test
508     public void testOne() {
509         Gradient one = build(17.0).getField().getOne();
510         Assert.assertEquals(1.0, one.getValue(), 1.0e-15);
511         for (int i = 0; i < one.getFreeParameters(); ++i) {
512             Assert.assertEquals(0.0, one.getPartialDerivative(i), 1.0e-15);
513         }
514     }
515 
516 }