1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23 package org.hipparchus.analysis.function;
24
25 import org.hipparchus.analysis.UnivariateFunction;
26 import org.hipparchus.analysis.differentiation.DSFactory;
27 import org.hipparchus.analysis.differentiation.DerivativeStructure;
28 import org.hipparchus.exception.MathIllegalArgumentException;
29 import org.hipparchus.exception.NullArgumentException;
30 import org.junit.jupiter.api.Test;
31
32 import static org.junit.jupiter.api.Assertions.assertEquals;
33 import static org.junit.jupiter.api.Assertions.assertThrows;
34
35
36
37
38 class SigmoidTest {
39 private final double EPS = Math.ulp(1d);
40
41 @Test
42 void testSomeValues() {
43 final UnivariateFunction f = new Sigmoid();
44
45 assertEquals(0.5, f.value(0), EPS);
46 assertEquals(0, f.value(Double.NEGATIVE_INFINITY), EPS);
47 assertEquals(1, f.value(Double.POSITIVE_INFINITY), EPS);
48 }
49
50 @Test
51 void testDerivative() {
52 final Sigmoid f = new Sigmoid();
53 final DerivativeStructure f0 = f.value(new DSFactory(1, 1).variable(0, 0.0));
54
55 assertEquals(0.25, f0.getPartialDerivative(1), 0);
56 }
57
58 @Test
59 void testDerivativesHighOrder() {
60 DerivativeStructure s = new Sigmoid(1, 3).value(new DSFactory(1, 5).variable(0, 1.2));
61 assertEquals(2.5370495669980352859, s.getPartialDerivative(0), 5.0e-16);
62 assertEquals(0.35578888129361140441, s.getPartialDerivative(1), 6.0e-17);
63 assertEquals(-0.19107626464144938116, s.getPartialDerivative(2), 6.0e-17);
64 assertEquals(-0.02396830286286711696, s.getPartialDerivative(3), 4.0e-17);
65 assertEquals(0.21682059798981049049, s.getPartialDerivative(4), 3.0e-17);
66 assertEquals(-0.19186320234632658055, s.getPartialDerivative(5), 2.0e-16);
67 }
68
69 @Test
70 void testDerivativeLargeArguments() {
71 final Sigmoid f = new Sigmoid(1, 2);
72
73 DSFactory factory = new DSFactory(1, 1);
74 assertEquals(0, f.value(factory.variable(0, Double.NEGATIVE_INFINITY)).getPartialDerivative(1), 0);
75 assertEquals(0, f.value(factory.variable(0, -Double.MAX_VALUE)).getPartialDerivative(1), 0);
76 assertEquals(0, f.value(factory.variable(0, -1e50)).getPartialDerivative(1), 0);
77 assertEquals(0, f.value(factory.variable(0, -1e3)).getPartialDerivative(1), 0);
78 assertEquals(0, f.value(factory.variable(0, 1e3)).getPartialDerivative(1), 0);
79 assertEquals(0, f.value(factory.variable(0, 1e50)).getPartialDerivative(1), 0);
80 assertEquals(0, f.value(factory.variable(0, Double.MAX_VALUE)).getPartialDerivative(1), 0);
81 assertEquals(0, f.value(factory.variable(0, Double.POSITIVE_INFINITY)).getPartialDerivative(1), 0);
82 }
83
84 @Test
85 void testParametricUsage1() {
86 assertThrows(NullArgumentException.class, () -> {
87 final Sigmoid.Parametric g = new Sigmoid.Parametric();
88 g.value(0, null);
89 });
90 }
91
92 @Test
93 void testParametricUsage2() {
94 assertThrows(MathIllegalArgumentException.class, () -> {
95 final Sigmoid.Parametric g = new Sigmoid.Parametric();
96 g.value(0, new double[]{0});
97 });
98 }
99
100 @Test
101 void testParametricUsage3() {
102 assertThrows(NullArgumentException.class, () -> {
103 final Sigmoid.Parametric g = new Sigmoid.Parametric();
104 g.gradient(0, null);
105 });
106 }
107
108 @Test
109 void testParametricUsage4() {
110 assertThrows(MathIllegalArgumentException.class, () -> {
111 final Sigmoid.Parametric g = new Sigmoid.Parametric();
112 g.gradient(0, new double[]{0});
113 });
114 }
115
116 @Test
117 void testParametricValue() {
118 final double lo = 2;
119 final double hi = 3;
120 final Sigmoid f = new Sigmoid(lo, hi);
121
122 final Sigmoid.Parametric g = new Sigmoid.Parametric();
123 assertEquals(f.value(-1), g.value(-1, new double[] {lo, hi}), 0);
124 assertEquals(f.value(0), g.value(0, new double[] {lo, hi}), 0);
125 assertEquals(f.value(2), g.value(2, new double[] {lo, hi}), 0);
126 }
127 }