1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * https://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 /*
19 * This is not the original file distributed by the Apache Software Foundation
20 * It has been modified by the Hipparchus project
21 */
22
23 package org.hipparchus.analysis.function;
24
25 import java.util.Arrays;
26
27 import org.hipparchus.analysis.ParametricUnivariateFunction;
28 import org.hipparchus.analysis.differentiation.Derivative;
29 import org.hipparchus.analysis.differentiation.UnivariateDifferentiableFunction;
30 import org.hipparchus.exception.MathIllegalArgumentException;
31 import org.hipparchus.exception.NullArgumentException;
32 import org.hipparchus.util.FastMath;
33 import org.hipparchus.util.MathUtils;
34
35 /**
36 * <a href="http://en.wikipedia.org/wiki/Sigmoid_function">
37 * Sigmoid</a> function.
38 * It is the inverse of the {@link Logit logit} function.
39 * A more flexible version, the generalised logistic, is implemented
40 * by the {@link Logistic} class.
41 *
42 */
43 public class Sigmoid implements UnivariateDifferentiableFunction {
44 /** Lower asymptote. */
45 private final double lo;
46 /** Higher asymptote. */
47 private final double hi;
48
49 /**
50 * Usual sigmoid function, where the lower asymptote is 0 and the higher
51 * asymptote is 1.
52 */
53 public Sigmoid() {
54 this(0, 1);
55 }
56
57 /**
58 * Sigmoid function.
59 *
60 * @param lo Lower asymptote.
61 * @param hi Higher asymptote.
62 */
63 public Sigmoid(double lo,
64 double hi) {
65 this.lo = lo;
66 this.hi = hi;
67 }
68
69 /** {@inheritDoc} */
70 @Override
71 public double value(double x) {
72 return value(x, lo, hi);
73 }
74
75 /**
76 * Parametric function where the input array contains the parameters of
77 * the {@link Sigmoid#Sigmoid(double,double) sigmoid function}, ordered
78 * as follows:
79 * <ul>
80 * <li>Lower asymptote</li>
81 * <li>Higher asymptote</li>
82 * </ul>
83 */
84 public static class Parametric implements ParametricUnivariateFunction {
85
86 /** Empty constructor.
87 * <p>
88 * This constructor is not strictly necessary, but it prevents spurious
89 * javadoc warnings with JDK 18 and later.
90 * </p>
91 * @since 3.0
92 */
93 public Parametric() { // NOPMD - unnecessary constructor added intentionally to make javadoc happy
94 // nothing to do
95 }
96
97 /**
98 * Computes the value of the sigmoid at {@code x}.
99 *
100 * @param x Value for which the function must be computed.
101 * @param param Values of lower asymptote and higher asymptote.
102 * @return the value of the function.
103 * @throws NullArgumentException if {@code param} is {@code null}.
104 * @throws MathIllegalArgumentException if the size of {@code param} is
105 * not 2.
106 */
107 @Override
108 public double value(double x, double ... param)
109 throws MathIllegalArgumentException, NullArgumentException {
110 validateParameters(param);
111 return Sigmoid.value(x, param[0], param[1]);
112 }
113
114 /**
115 * Computes the value of the gradient at {@code x}.
116 * The components of the gradient vector are the partial
117 * derivatives of the function with respect to each of the
118 * <em>parameters</em> (lower asymptote and higher asymptote).
119 *
120 * @param x Value at which the gradient must be computed.
121 * @param param Values for lower asymptote and higher asymptote.
122 * @return the gradient vector at {@code x}.
123 * @throws NullArgumentException if {@code param} is {@code null}.
124 * @throws MathIllegalArgumentException if the size of {@code param} is
125 * not 2.
126 */
127 @Override
128 public double[] gradient(double x, double ... param)
129 throws MathIllegalArgumentException, NullArgumentException {
130 validateParameters(param);
131
132 final double invExp1 = 1 / (1 + FastMath.exp(-x));
133
134 return new double[] { 1 - invExp1, invExp1 };
135 }
136
137 /**
138 * Validates parameters to ensure they are appropriate for the evaluation of
139 * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
140 * methods.
141 *
142 * @param param Values for lower and higher asymptotes.
143 * @throws NullArgumentException if {@code param} is {@code null}.
144 * @throws MathIllegalArgumentException if the size of {@code param} is
145 * not 2.
146 */
147 private void validateParameters(double[] param)
148 throws MathIllegalArgumentException, NullArgumentException {
149 MathUtils.checkNotNull(param);
150 MathUtils.checkDimension(param.length, 2);
151 }
152 }
153
154 /**
155 * @param x Value at which to compute the sigmoid.
156 * @param lo Lower asymptote.
157 * @param hi Higher asymptote.
158 * @return the value of the sigmoid function at {@code x}.
159 */
160 private static double value(double x,
161 double lo,
162 double hi) {
163 return lo + (hi - lo) / (1 + FastMath.exp(-x));
164 }
165
166 /** {@inheritDoc}
167 */
168 @Override
169 public <T extends Derivative<T>> T value(T t)
170 throws MathIllegalArgumentException {
171
172 double[] f = new double[t.getOrder() + 1];
173 final double exp = FastMath.exp(-t.getValue());
174 if (Double.isInfinite(exp)) {
175
176 // special handling near lower boundary, to avoid NaN
177 f[0] = lo;
178 Arrays.fill(f, 1, f.length, 0.0);
179
180 } else {
181
182 // the nth order derivative of sigmoid has the form:
183 // dn(sigmoid(x)/dxn = P_n(exp(-x)) / (1+exp(-x))^(n+1)
184 // where P_n(t) is a degree n polynomial with normalized higher term
185 // P_0(t) = 1, P_1(t) = t, P_2(t) = t^2 - t, P_3(t) = t^3 - 4 t^2 + t...
186 // the general recurrence relation for P_n is:
187 // P_n(x) = n t P_(n-1)(t) - t (1 + t) P_(n-1)'(t)
188 final double[] p = new double[f.length];
189
190 final double inv = 1 / (1 + exp);
191 double coeff = hi - lo;
192 for (int n = 0; n < f.length; ++n) {
193
194 // update and evaluate polynomial P_n(t)
195 double v = 0;
196 p[n] = 1;
197 for (int k = n; k >= 0; --k) {
198 v = v * exp + p[k];
199 if (k > 1) {
200 p[k - 1] = (n - k + 2) * p[k - 2] - (k - 1) * p[k - 1];
201 } else {
202 p[0] = 0;
203 }
204 }
205
206 coeff *= inv;
207 f[n] = coeff * v;
208
209 }
210
211 // fix function value
212 f[0] += lo;
213
214 }
215
216 return t.compose(f);
217
218 }
219
220 }