1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23 package org.hipparchus.analysis.function;
24
25 import org.hipparchus.analysis.UnivariateFunction;
26 import org.hipparchus.analysis.differentiation.DSFactory;
27 import org.hipparchus.analysis.differentiation.DerivativeStructure;
28 import org.hipparchus.exception.MathIllegalArgumentException;
29 import org.hipparchus.exception.NullArgumentException;
30 import org.junit.Assert;
31 import org.junit.Test;
32
33
34
35
36 public class SigmoidTest {
37 private final double EPS = Math.ulp(1d);
38
39 @Test
40 public void testSomeValues() {
41 final UnivariateFunction f = new Sigmoid();
42
43 Assert.assertEquals(0.5, f.value(0), EPS);
44 Assert.assertEquals(0, f.value(Double.NEGATIVE_INFINITY), EPS);
45 Assert.assertEquals(1, f.value(Double.POSITIVE_INFINITY), EPS);
46 }
47
48 @Test
49 public void testDerivative() {
50 final Sigmoid f = new Sigmoid();
51 final DerivativeStructure f0 = f.value(new DSFactory(1, 1).variable(0, 0.0));
52
53 Assert.assertEquals(0.25, f0.getPartialDerivative(1), 0);
54 }
55
56 @Test
57 public void testDerivativesHighOrder() {
58 DerivativeStructure s = new Sigmoid(1, 3).value(new DSFactory(1, 5).variable(0, 1.2));
59 Assert.assertEquals(2.5370495669980352859, s.getPartialDerivative(0), 5.0e-16);
60 Assert.assertEquals(0.35578888129361140441, s.getPartialDerivative(1), 6.0e-17);
61 Assert.assertEquals(-0.19107626464144938116, s.getPartialDerivative(2), 6.0e-17);
62 Assert.assertEquals(-0.02396830286286711696, s.getPartialDerivative(3), 4.0e-17);
63 Assert.assertEquals(0.21682059798981049049, s.getPartialDerivative(4), 3.0e-17);
64 Assert.assertEquals(-0.19186320234632658055, s.getPartialDerivative(5), 2.0e-16);
65 }
66
67 @Test
68 public void testDerivativeLargeArguments() {
69 final Sigmoid f = new Sigmoid(1, 2);
70
71 DSFactory factory = new DSFactory(1, 1);
72 Assert.assertEquals(0, f.value(factory.variable(0, Double.NEGATIVE_INFINITY)).getPartialDerivative(1), 0);
73 Assert.assertEquals(0, f.value(factory.variable(0, -Double.MAX_VALUE)).getPartialDerivative(1), 0);
74 Assert.assertEquals(0, f.value(factory.variable(0, -1e50)).getPartialDerivative(1), 0);
75 Assert.assertEquals(0, f.value(factory.variable(0, -1e3)).getPartialDerivative(1), 0);
76 Assert.assertEquals(0, f.value(factory.variable(0, 1e3)).getPartialDerivative(1), 0);
77 Assert.assertEquals(0, f.value(factory.variable(0, 1e50)).getPartialDerivative(1), 0);
78 Assert.assertEquals(0, f.value(factory.variable(0, Double.MAX_VALUE)).getPartialDerivative(1), 0);
79 Assert.assertEquals(0, f.value(factory.variable(0, Double.POSITIVE_INFINITY)).getPartialDerivative(1), 0);
80 }
81
82 @Test(expected=NullArgumentException.class)
83 public void testParametricUsage1() {
84 final Sigmoid.Parametric g = new Sigmoid.Parametric();
85 g.value(0, null);
86 }
87
88 @Test(expected=MathIllegalArgumentException.class)
89 public void testParametricUsage2() {
90 final Sigmoid.Parametric g = new Sigmoid.Parametric();
91 g.value(0, new double[] {0});
92 }
93
94 @Test(expected=NullArgumentException.class)
95 public void testParametricUsage3() {
96 final Sigmoid.Parametric g = new Sigmoid.Parametric();
97 g.gradient(0, null);
98 }
99
100 @Test(expected=MathIllegalArgumentException.class)
101 public void testParametricUsage4() {
102 final Sigmoid.Parametric g = new Sigmoid.Parametric();
103 g.gradient(0, new double[] {0});
104 }
105
106 @Test
107 public void testParametricValue() {
108 final double lo = 2;
109 final double hi = 3;
110 final Sigmoid f = new Sigmoid(lo, hi);
111
112 final Sigmoid.Parametric g = new Sigmoid.Parametric();
113 Assert.assertEquals(f.value(-1), g.value(-1, new double[] {lo, hi}), 0);
114 Assert.assertEquals(f.value(0), g.value(0, new double[] {lo, hi}), 0);
115 Assert.assertEquals(f.value(2), g.value(2, new double[] {lo, hi}), 0);
116 }
117 }