1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23 package org.hipparchus.analysis.differentiation;
24
25 import org.hipparchus.UnitTestUtils;
26 import org.hipparchus.exception.MathIllegalArgumentException;
27 import org.hipparchus.util.FastMath;
28 import org.junit.Test;
29
30
31
32
33
34 public class GradientFunctionTest {
35
36 @Test
37 public void test2DDistance() {
38 EuclideanDistance f = new EuclideanDistance();
39 GradientFunction g = new GradientFunction(f);
40 for (double x = -10; x < 10; x += 0.5) {
41 for (double y = -10; y < 10; y += 0.5) {
42 double[] point = new double[] { x, y };
43 UnitTestUtils.assertEquals(f.gradient(point), g.value(point), 1.0e-15);
44 }
45 }
46 }
47
48 @Test
49 public void test3DDistance() {
50 EuclideanDistance f = new EuclideanDistance();
51 GradientFunction g = new GradientFunction(f);
52 for (double x = -10; x < 10; x += 0.5) {
53 for (double y = -10; y < 10; y += 0.5) {
54 for (double z = -10; z < 10; z += 0.5) {
55 double[] point = new double[] { x, y, z };
56 UnitTestUtils.assertEquals(f.gradient(point), g.value(point), 1.0e-15);
57 }
58 }
59 }
60 }
61
62 private static class EuclideanDistance implements MultivariateDifferentiableFunction {
63
64 @Override
65 public double value(double[] point) {
66 double d2 = 0;
67 for (double x : point) {
68 d2 += x * x;
69 }
70 return FastMath.sqrt(d2);
71 }
72
73 @Override
74 public DerivativeStructure value(DerivativeStructure[] point)
75 throws MathIllegalArgumentException {
76 DerivativeStructure d2 = point[0].getField().getZero();
77 for (DerivativeStructure x : point) {
78 d2 = d2.add(x.square());
79 }
80 return d2.sqrt();
81 }
82
83 public double[] gradient(double[] point) {
84 double[] gradient = new double[point.length];
85 double d = value(point);
86 for (int i = 0; i < point.length; ++i) {
87 gradient[i] = point[i] / d;
88 }
89 return gradient;
90 }
91
92 }
93
94 }