1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23 package org.hipparchus.analysis.function;
24
25 import org.hipparchus.analysis.UnivariateFunction;
26 import org.hipparchus.analysis.differentiation.DSFactory;
27 import org.hipparchus.analysis.differentiation.DerivativeStructure;
28 import org.hipparchus.exception.MathIllegalArgumentException;
29 import org.hipparchus.exception.NullArgumentException;
30 import org.hipparchus.util.FastMath;
31 import org.junit.Assert;
32 import org.junit.Test;
33
34
35
36
37 public class LogisticTest {
38 private final double EPS = Math.ulp(1d);
39
40 @Test(expected=MathIllegalArgumentException.class)
41 public void testPreconditions1() {
42 new Logistic(1, 0, 1, 1, 0, -1);
43 }
44
45 @Test(expected=MathIllegalArgumentException.class)
46 public void testPreconditions2() {
47 new Logistic(1, 0, 1, 1, 0, 0);
48 }
49
50 @Test
51 public void testCompareSigmoid() {
52 final UnivariateFunction sig = new Sigmoid();
53 final UnivariateFunction sigL = new Logistic(1, 0, 1, 1, 0, 1);
54
55 final double min = -2;
56 final double max = 2;
57 final int n = 100;
58 final double delta = (max - min) / n;
59 for (int i = 0; i < n; i++) {
60 final double x = min + i * delta;
61 Assert.assertEquals("x=" + x, sig.value(x), sigL.value(x), EPS);
62 }
63 }
64
65 @Test
66 public void testSomeValues() {
67 final double k = 4;
68 final double m = 5;
69 final double b = 2;
70 final double q = 3;
71 final double a = -1;
72 final double n = 2;
73
74 final UnivariateFunction f = new Logistic(k, m, b, q, a, n);
75
76 double x;
77 x = m;
78 Assert.assertEquals("x=" + x, a + (k - a) / FastMath.sqrt(1 + q), f.value(x), EPS);
79
80 x = Double.NEGATIVE_INFINITY;
81 Assert.assertEquals("x=" + x, a, f.value(x), EPS);
82
83 x = Double.POSITIVE_INFINITY;
84 Assert.assertEquals("x=" + x, k, f.value(x), EPS);
85 }
86
87 @Test
88 public void testCompareDerivativeSigmoid() {
89 final double k = 3;
90 final double a = 2;
91
92 final Logistic f = new Logistic(k, 0, 1, 1, a, 1);
93 final Sigmoid g = new Sigmoid(a, k);
94
95 final double min = -10;
96 final double max = 10;
97 final double n = 20;
98 final double delta = (max - min) / n;
99 final DSFactory factory = new DSFactory(1, 5);
100 for (int i = 0; i < n; i++) {
101 final DerivativeStructure x = factory.variable(0, min + i * delta);
102 for (int order = 0; order <= x.getOrder(); ++order) {
103 Assert.assertEquals("x=" + x.getValue(),
104 g.value(x).getPartialDerivative(order),
105 f.value(x).getPartialDerivative(order),
106 3.0e-15);
107 }
108 }
109 }
110
111 @Test(expected=NullArgumentException.class)
112 public void testParametricUsage1() {
113 final Logistic.Parametric g = new Logistic.Parametric();
114 g.value(0, null);
115 }
116
117 @Test(expected=MathIllegalArgumentException.class)
118 public void testParametricUsage2() {
119 final Logistic.Parametric g = new Logistic.Parametric();
120 g.value(0, new double[] {0});
121 }
122
123 @Test(expected=NullArgumentException.class)
124 public void testParametricUsage3() {
125 final Logistic.Parametric g = new Logistic.Parametric();
126 g.gradient(0, null);
127 }
128
129 @Test(expected=MathIllegalArgumentException.class)
130 public void testParametricUsage4() {
131 final Logistic.Parametric g = new Logistic.Parametric();
132 g.gradient(0, new double[] {0});
133 }
134
135 @Test(expected=MathIllegalArgumentException.class)
136 public void testParametricUsage5() {
137 final Logistic.Parametric g = new Logistic.Parametric();
138 g.value(0, new double[] {1, 0, 1, 1, 0 ,0});
139 }
140
141 @Test(expected=MathIllegalArgumentException.class)
142 public void testParametricUsage6() {
143 final Logistic.Parametric g = new Logistic.Parametric();
144 g.gradient(0, new double[] {1, 0, 1, 1, 0 ,0});
145 }
146
147 @Test
148 public void testGradientComponent0Component4() {
149 final double k = 3;
150 final double a = 2;
151
152 final Logistic.Parametric f = new Logistic.Parametric();
153
154 final Sigmoid.Parametric g = new Sigmoid.Parametric();
155
156 final double x = 0.12345;
157 final double[] gf = f.gradient(x, new double[] {k, 0, 1, 1, a, 1});
158 final double[] gg = g.gradient(x, new double[] {a, k});
159
160 Assert.assertEquals(gg[0], gf[4], EPS);
161 Assert.assertEquals(gg[1], gf[0], EPS);
162 }
163
164 @Test
165 public void testGradientComponent5() {
166 final double m = 1.2;
167 final double k = 3.4;
168 final double a = 2.3;
169 final double q = 0.567;
170 final double b = -FastMath.log(q);
171 final double n = 3.4;
172
173 final Logistic.Parametric f = new Logistic.Parametric();
174
175 final double x = m - 1;
176 final double qExp1 = 2;
177
178 final double[] gf = f.gradient(x, new double[] {k, m, b, q, a, n});
179
180 Assert.assertEquals((k - a) * FastMath.log(qExp1) / (n * n * FastMath.pow(qExp1, 1 / n)),
181 gf[5], EPS);
182 }
183
184 @Test
185 public void testGradientComponent1Component2Component3() {
186 final double m = 1.2;
187 final double k = 3.4;
188 final double a = 2.3;
189 final double b = 0.567;
190 final double q = 1 / FastMath.exp(b * m);
191 final double n = 3.4;
192
193 final Logistic.Parametric f = new Logistic.Parametric();
194
195 final double x = 0;
196 final double qExp1 = 2;
197
198 final double[] gf = f.gradient(x, new double[] {k, m, b, q, a, n});
199
200 final double factor = (a - k) / (n * FastMath.pow(qExp1, 1 / n + 1));
201 Assert.assertEquals(factor * b, gf[1], EPS);
202 Assert.assertEquals(factor * m, gf[2], EPS);
203 Assert.assertEquals(factor / q, gf[3], EPS);
204 }
205 }