1   
2   
3   
4   
5   
6   
7   
8   
9   
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  package org.hipparchus.analysis.differentiation;
24  
25  import org.hipparchus.UnitTestUtils;
26  import org.hipparchus.exception.MathIllegalArgumentException;
27  import org.hipparchus.util.FastMath;
28  import org.junit.jupiter.api.Test;
29  
30  
31  
32  
33  
34  class GradientFunctionTest {
35  
36      @Test
37      void test2DDistance() {
38          EuclideanDistance f = new EuclideanDistance();
39          GradientFunction g = new GradientFunction(f);
40          for (double x = -10; x < 10; x += 0.5) {
41              for (double y = -10; y < 10; y += 0.5) {
42                  double[] point = new double[] { x, y };
43                  UnitTestUtils.customAssertEquals(f.gradient(point), g.value(point), 1.0e-15);
44              }
45          }
46      }
47  
48      @Test
49      void test3DDistance() {
50          EuclideanDistance f = new EuclideanDistance();
51          GradientFunction g = new GradientFunction(f);
52          for (double x = -10; x < 10; x += 0.5) {
53              for (double y = -10; y < 10; y += 0.5) {
54                  for (double z = -10; z < 10; z += 0.5) {
55                      double[] point = new double[] { x, y, z };
56                      UnitTestUtils.customAssertEquals(f.gradient(point), g.value(point), 1.0e-15);
57                  }
58              }
59          }
60      }
61  
62      private static class EuclideanDistance implements MultivariateDifferentiableFunction {
63  
64          @Override
65          public double value(double[] point) {
66              double d2 = 0;
67              for (double x : point) {
68                  d2 += x * x;
69              }
70              return FastMath.sqrt(d2);
71          }
72  
73          @Override
74          public DerivativeStructure value(DerivativeStructure[] point)
75              throws MathIllegalArgumentException {
76              DerivativeStructure d2 = point[0].getField().getZero();
77              for (DerivativeStructure x : point) {
78                  d2 = d2.add(x.square());
79              }
80              return d2.sqrt();
81          }
82  
83          public double[] gradient(double[] point) {
84              double[] gradient = new double[point.length];
85              double d = value(point);
86              for (int i = 0; i < point.length; ++i) {
87                  gradient[i] = point[i] / d;
88              }
89              return gradient;
90          }
91  
92      }
93  
94  }