View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  
23  package org.hipparchus.analysis.function;
24  
25  import org.hipparchus.analysis.UnivariateFunction;
26  import org.hipparchus.analysis.differentiation.DSFactory;
27  import org.hipparchus.analysis.differentiation.DerivativeStructure;
28  import org.hipparchus.exception.MathIllegalArgumentException;
29  import org.hipparchus.exception.NullArgumentException;
30  import org.hipparchus.util.FastMath;
31  import org.junit.Assert;
32  import org.junit.Test;
33  
34  /**
35   * Test for class {@link Logistic}.
36   */
37  public class LogisticTest {
38      private final double EPS = Math.ulp(1d);
39  
40      @Test(expected=MathIllegalArgumentException.class)
41      public void testPreconditions1() {
42          new Logistic(1, 0, 1, 1, 0, -1);
43      }
44  
45      @Test(expected=MathIllegalArgumentException.class)
46      public void testPreconditions2() {
47          new Logistic(1, 0, 1, 1, 0, 0);
48      }
49  
50      @Test
51      public void testCompareSigmoid() {
52          final UnivariateFunction sig = new Sigmoid();
53          final UnivariateFunction sigL = new Logistic(1, 0, 1, 1, 0, 1);
54  
55          final double min = -2;
56          final double max = 2;
57          final int n = 100;
58          final double delta = (max - min) / n;
59          for (int i = 0; i < n; i++) {
60              final double x = min + i * delta;
61              Assert.assertEquals("x=" + x, sig.value(x), sigL.value(x), EPS);
62          }
63      }
64  
65      @Test
66      public void testSomeValues() {
67          final double k = 4;
68          final double m = 5;
69          final double b = 2;
70          final double q = 3;
71          final double a = -1;
72          final double n = 2;
73  
74          final UnivariateFunction f = new Logistic(k, m, b, q, a, n);
75  
76          double x;
77          x = m;
78          Assert.assertEquals("x=" + x, a + (k - a) / FastMath.sqrt(1 + q), f.value(x), EPS);
79  
80          x = Double.NEGATIVE_INFINITY;
81          Assert.assertEquals("x=" + x, a, f.value(x), EPS);
82  
83          x = Double.POSITIVE_INFINITY;
84          Assert.assertEquals("x=" + x, k, f.value(x), EPS);
85      }
86  
87      @Test
88      public void testCompareDerivativeSigmoid() {
89          final double k = 3;
90          final double a = 2;
91  
92          final Logistic f = new Logistic(k, 0, 1, 1, a, 1);
93          final Sigmoid g = new Sigmoid(a, k);
94  
95          final double min = -10;
96          final double max = 10;
97          final double n = 20;
98          final double delta = (max - min) / n;
99          final DSFactory factory = new DSFactory(1, 5);
100         for (int i = 0; i < n; i++) {
101             final DerivativeStructure x = factory.variable(0, min + i * delta);
102             for (int order = 0; order <= x.getOrder(); ++order) {
103                 Assert.assertEquals("x=" + x.getValue(),
104                                     g.value(x).getPartialDerivative(order),
105                                     f.value(x).getPartialDerivative(order),
106                                     3.0e-15);
107             }
108         }
109     }
110 
111     @Test(expected=NullArgumentException.class)
112     public void testParametricUsage1() {
113         final Logistic.Parametric g = new Logistic.Parametric();
114         g.value(0, null);
115     }
116 
117     @Test(expected=MathIllegalArgumentException.class)
118     public void testParametricUsage2() {
119         final Logistic.Parametric g = new Logistic.Parametric();
120         g.value(0, new double[] {0});
121     }
122 
123     @Test(expected=NullArgumentException.class)
124     public void testParametricUsage3() {
125         final Logistic.Parametric g = new Logistic.Parametric();
126         g.gradient(0, null);
127     }
128 
129     @Test(expected=MathIllegalArgumentException.class)
130     public void testParametricUsage4() {
131         final Logistic.Parametric g = new Logistic.Parametric();
132         g.gradient(0, new double[] {0});
133     }
134 
135     @Test(expected=MathIllegalArgumentException.class)
136     public void testParametricUsage5() {
137         final Logistic.Parametric g = new Logistic.Parametric();
138         g.value(0, new double[] {1, 0, 1, 1, 0 ,0});
139     }
140 
141     @Test(expected=MathIllegalArgumentException.class)
142     public void testParametricUsage6() {
143         final Logistic.Parametric g = new Logistic.Parametric();
144         g.gradient(0, new double[] {1, 0, 1, 1, 0 ,0});
145     }
146 
147     @Test
148     public void testGradientComponent0Component4() {
149         final double k = 3;
150         final double a = 2;
151 
152         final Logistic.Parametric f = new Logistic.Parametric();
153         // Compare using the "Sigmoid" function.
154         final Sigmoid.Parametric g = new Sigmoid.Parametric();
155 
156         final double x = 0.12345;
157         final double[] gf = f.gradient(x, new double[] {k, 0, 1, 1, a, 1});
158         final double[] gg = g.gradient(x, new double[] {a, k});
159 
160         Assert.assertEquals(gg[0], gf[4], EPS);
161         Assert.assertEquals(gg[1], gf[0], EPS);
162     }
163 
164     @Test
165     public void testGradientComponent5() {
166         final double m = 1.2;
167         final double k = 3.4;
168         final double a = 2.3;
169         final double q = 0.567;
170         final double b = -FastMath.log(q);
171         final double n = 3.4;
172 
173         final Logistic.Parametric f = new Logistic.Parametric();
174 
175         final double x = m - 1;
176         final double qExp1 = 2;
177 
178         final double[] gf = f.gradient(x, new double[] {k, m, b, q, a, n});
179 
180         Assert.assertEquals((k - a) * FastMath.log(qExp1) / (n * n * FastMath.pow(qExp1, 1 / n)),
181                             gf[5], EPS);
182     }
183 
184     @Test
185     public void testGradientComponent1Component2Component3() {
186         final double m = 1.2;
187         final double k = 3.4;
188         final double a = 2.3;
189         final double b = 0.567;
190         final double q = 1 / FastMath.exp(b * m);
191         final double n = 3.4;
192 
193         final Logistic.Parametric f = new Logistic.Parametric();
194 
195         final double x = 0;
196         final double qExp1 = 2;
197 
198         final double[] gf = f.gradient(x, new double[] {k, m, b, q, a, n});
199 
200         final double factor = (a - k) / (n * FastMath.pow(qExp1, 1 / n + 1));
201         Assert.assertEquals(factor * b, gf[1], EPS);
202         Assert.assertEquals(factor * m, gf[2], EPS);
203         Assert.assertEquals(factor / q, gf[3], EPS);
204     }
205 }