1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23 package org.hipparchus.analysis.function;
24
25 import org.hipparchus.analysis.FunctionUtils;
26 import org.hipparchus.analysis.UnivariateFunction;
27 import org.hipparchus.analysis.differentiation.DSFactory;
28 import org.hipparchus.analysis.differentiation.DerivativeStructure;
29 import org.hipparchus.analysis.differentiation.UnivariateDifferentiableFunction;
30 import org.hipparchus.exception.MathIllegalArgumentException;
31 import org.hipparchus.exception.NullArgumentException;
32 import org.hipparchus.random.RandomGenerator;
33 import org.hipparchus.random.Well1024a;
34 import org.hipparchus.util.FastMath;
35 import org.junit.Assert;
36 import org.junit.Test;
37
38
39
40
41 public class LogitTest {
42 private final double EPS = Math.ulp(1d);
43
44 @Test(expected=MathIllegalArgumentException.class)
45 public void testPreconditions1() {
46 final double lo = -1;
47 final double hi = 2;
48 final UnivariateFunction f = new Logit(lo, hi);
49
50 f.value(lo - 1);
51 }
52
53 @Test(expected=MathIllegalArgumentException.class)
54 public void testPreconditions2() {
55 final double lo = -1;
56 final double hi = 2;
57 final UnivariateFunction f = new Logit(lo, hi);
58
59 f.value(hi + 1);
60 }
61
62 @Test
63 public void testSomeValues() {
64 final double lo = 1;
65 final double hi = 2;
66 final UnivariateFunction f = new Logit(lo, hi);
67
68 Assert.assertEquals(Double.NEGATIVE_INFINITY, f.value(1), EPS);
69 Assert.assertEquals(Double.POSITIVE_INFINITY, f.value(2), EPS);
70 Assert.assertEquals(0, f.value(1.5), EPS);
71 }
72
73 @Test
74 public void testDerivative() {
75 final double lo = 1;
76 final double hi = 2;
77 final Logit f = new Logit(lo, hi);
78 final DerivativeStructure f15 = f.value(new DSFactory(1, 1).variable(0, 1.5));
79
80 Assert.assertEquals(4, f15.getPartialDerivative(1), EPS);
81 }
82
83 @Test
84 public void testDerivativeLargeArguments() {
85 final Logit f = new Logit(1, 2);
86
87 DSFactory factory = new DSFactory(1, 1);
88 for (double arg : new double[] {
89 Double.NEGATIVE_INFINITY, -Double.MAX_VALUE, -1e155, 1e155, Double.MAX_VALUE, Double.POSITIVE_INFINITY
90 }) {
91 try {
92 f.value(factory.variable(0, arg));
93 Assert.fail("an exception should have been thrown");
94 } catch (MathIllegalArgumentException ore) {
95
96 } catch (Exception e) {
97 Assert.fail("wrong exception caught: " + e.getMessage());
98 }
99 }
100 }
101
102 @Test
103 public void testDerivativesHighOrder() {
104 DerivativeStructure l = new Logit(1, 3).value(new DSFactory(1, 5).variable(0, 1.2));
105 Assert.assertEquals(-2.1972245773362193828, l.getPartialDerivative(0), 1.0e-16);
106 Assert.assertEquals(5.5555555555555555555, l.getPartialDerivative(1), 9.0e-16);
107 Assert.assertEquals(-24.691358024691358025, l.getPartialDerivative(2), 2.0e-14);
108 Assert.assertEquals(250.34293552812071331, l.getPartialDerivative(3), 2.0e-13);
109 Assert.assertEquals(-3749.4284407864654778, l.getPartialDerivative(4), 4.0e-12);
110 Assert.assertEquals(75001.270131585632282, l.getPartialDerivative(5), 8.0e-11);
111 }
112
113 @Test(expected=NullArgumentException.class)
114 public void testParametricUsage1() {
115 final Logit.Parametric g = new Logit.Parametric();
116 g.value(0, null);
117 }
118
119 @Test(expected=MathIllegalArgumentException.class)
120 public void testParametricUsage2() {
121 final Logit.Parametric g = new Logit.Parametric();
122 g.value(0, new double[] {0});
123 }
124
125 @Test(expected=NullArgumentException.class)
126 public void testParametricUsage3() {
127 final Logit.Parametric g = new Logit.Parametric();
128 g.gradient(0, null);
129 }
130
131 @Test(expected=MathIllegalArgumentException.class)
132 public void testParametricUsage4() {
133 final Logit.Parametric g = new Logit.Parametric();
134 g.gradient(0, new double[] {0});
135 }
136
137 @Test(expected=MathIllegalArgumentException.class)
138 public void testParametricUsage5() {
139 final Logit.Parametric g = new Logit.Parametric();
140 g.value(-1, new double[] {0, 1});
141 }
142
143 @Test(expected=MathIllegalArgumentException.class)
144 public void testParametricUsage6() {
145 final Logit.Parametric g = new Logit.Parametric();
146 g.value(2, new double[] {0, 1});
147 }
148
149 @Test
150 public void testParametricValue() {
151 final double lo = 2;
152 final double hi = 3;
153 final Logit f = new Logit(lo, hi);
154
155 final Logit.Parametric g = new Logit.Parametric();
156 Assert.assertEquals(f.value(2), g.value(2, new double[] {lo, hi}), 0);
157 Assert.assertEquals(f.value(2.34567), g.value(2.34567, new double[] {lo, hi}), 0);
158 Assert.assertEquals(f.value(3), g.value(3, new double[] {lo, hi}), 0);
159 }
160
161 @Test
162 public void testValueWithInverseFunction() {
163 final double lo = 2;
164 final double hi = 3;
165 final Logit f = new Logit(lo, hi);
166 final Sigmoid g = new Sigmoid(lo, hi);
167 RandomGenerator random = new Well1024a(0x49914cdd9f0b8db5l);
168 final UnivariateDifferentiableFunction id = FunctionUtils.compose((UnivariateDifferentiableFunction) g,
169 (UnivariateDifferentiableFunction) f);
170
171 DSFactory factory = new DSFactory(1, 1);
172 for (int i = 0; i < 10; i++) {
173 final double x = lo + random.nextDouble() * (hi - lo);
174 Assert.assertEquals(x, id.value(factory.variable(0, x)).getValue(), EPS);
175 }
176
177 Assert.assertEquals(lo, id.value(factory.variable(0, lo)).getValue(), EPS);
178 Assert.assertEquals(hi, id.value(factory.variable(0, hi)).getValue(), EPS);
179 }
180
181 @Test
182 public void testDerivativesWithInverseFunction() {
183 double[] epsilon = new double[] { 1.0e-20, 4.0e-16, 3.0e-15, 2.0e-11, 3.0e-9, 1.0e-6 };
184 final double lo = 2;
185 final double hi = 3;
186 final Logit f = new Logit(lo, hi);
187 final Sigmoid g = new Sigmoid(lo, hi);
188 RandomGenerator random = new Well1024a(0x96885e9c1f81cea5l);
189 final UnivariateDifferentiableFunction id =
190 FunctionUtils.compose((UnivariateDifferentiableFunction) g, (UnivariateDifferentiableFunction) f);
191 for (int maxOrder = 0; maxOrder < 6; ++maxOrder) {
192 DSFactory factory = new DSFactory(1, maxOrder);
193 double max = 0;
194 for (int i = 0; i < 10; i++) {
195 final double x = lo + random.nextDouble() * (hi - lo);
196 final DerivativeStructure dsX = factory.variable(0, x);
197 max = FastMath.max(max, FastMath.abs(dsX.getPartialDerivative(maxOrder) -
198 id.value(dsX).getPartialDerivative(maxOrder)));
199 Assert.assertEquals(dsX.getPartialDerivative(maxOrder),
200 id.value(dsX).getPartialDerivative(maxOrder),
201 epsilon[maxOrder]);
202 }
203
204
205
206 final DerivativeStructure dsLo = factory.variable(0, lo);
207 if (maxOrder == 0) {
208 Assert.assertTrue(Double.isInfinite(f.value(dsLo).getPartialDerivative(maxOrder)));
209 Assert.assertEquals(lo, id.value(dsLo).getPartialDerivative(maxOrder), epsilon[maxOrder]);
210 } else if (maxOrder == 1) {
211 Assert.assertTrue(Double.isInfinite(f.value(dsLo).getPartialDerivative(maxOrder)));
212 Assert.assertTrue(Double.isNaN(id.value(dsLo).getPartialDerivative(maxOrder)));
213 } else {
214 Assert.assertTrue(Double.isNaN(f.value(dsLo).getPartialDerivative(maxOrder)));
215 Assert.assertTrue(Double.isNaN(id.value(dsLo).getPartialDerivative(maxOrder)));
216 }
217
218 final DerivativeStructure dsHi = factory.variable(0, hi);
219 if (maxOrder == 0) {
220 Assert.assertTrue(Double.isInfinite(f.value(dsHi).getPartialDerivative(maxOrder)));
221 Assert.assertEquals(hi, id.value(dsHi).getPartialDerivative(maxOrder), epsilon[maxOrder]);
222 } else if (maxOrder == 1) {
223 Assert.assertTrue(Double.isInfinite(f.value(dsHi).getPartialDerivative(maxOrder)));
224 Assert.assertTrue(Double.isNaN(id.value(dsHi).getPartialDerivative(maxOrder)));
225 } else {
226 Assert.assertTrue(Double.isNaN(f.value(dsHi).getPartialDerivative(maxOrder)));
227 Assert.assertTrue(Double.isNaN(id.value(dsHi).getPartialDerivative(maxOrder)));
228 }
229
230 }
231 }
232 }