View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  
23  package org.hipparchus.analysis.differentiation;
24  
25  import org.hipparchus.UnitTestUtils;
26  import org.hipparchus.exception.MathIllegalArgumentException;
27  import org.hipparchus.util.FastMath;
28  import org.junit.Test;
29  
30  
31  /**
32   * Test for class {@link GradientFunction}.
33   */
34  public class GradientFunctionTest {
35  
36      @Test
37      public void test2DDistance() {
38          EuclideanDistance f = new EuclideanDistance();
39          GradientFunction g = new GradientFunction(f);
40          for (double x = -10; x < 10; x += 0.5) {
41              for (double y = -10; y < 10; y += 0.5) {
42                  double[] point = new double[] { x, y };
43                  UnitTestUtils.assertEquals(f.gradient(point), g.value(point), 1.0e-15);
44              }
45          }
46      }
47  
48      @Test
49      public void test3DDistance() {
50          EuclideanDistance f = new EuclideanDistance();
51          GradientFunction g = new GradientFunction(f);
52          for (double x = -10; x < 10; x += 0.5) {
53              for (double y = -10; y < 10; y += 0.5) {
54                  for (double z = -10; z < 10; z += 0.5) {
55                      double[] point = new double[] { x, y, z };
56                      UnitTestUtils.assertEquals(f.gradient(point), g.value(point), 1.0e-15);
57                  }
58              }
59          }
60      }
61  
62      private static class EuclideanDistance implements MultivariateDifferentiableFunction {
63  
64          @Override
65          public double value(double[] point) {
66              double d2 = 0;
67              for (double x : point) {
68                  d2 += x * x;
69              }
70              return FastMath.sqrt(d2);
71          }
72  
73          @Override
74          public DerivativeStructure value(DerivativeStructure[] point)
75              throws MathIllegalArgumentException {
76              DerivativeStructure d2 = point[0].getField().getZero();
77              for (DerivativeStructure x : point) {
78                  d2 = d2.add(x.square());
79              }
80              return d2.sqrt();
81          }
82  
83          public double[] gradient(double[] point) {
84              double[] gradient = new double[point.length];
85              double d = value(point);
86              for (int i = 0; i < point.length; ++i) {
87                  gradient[i] = point[i] / d;
88              }
89              return gradient;
90          }
91  
92      }
93  
94  }