View Javadoc
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * This is not the original file distributed by the Apache Software Foundation
 * It has been modified by the Hipparchus project
 */
22  
23  package org.hipparchus.analysis.function;
24  
25  import java.util.Arrays;
26  
27  import org.hipparchus.analysis.ParametricUnivariateFunction;
28  import org.hipparchus.analysis.differentiation.Derivative;
29  import org.hipparchus.analysis.differentiation.UnivariateDifferentiableFunction;
30  import org.hipparchus.exception.MathIllegalArgumentException;
31  import org.hipparchus.exception.NullArgumentException;
32  import org.hipparchus.util.FastMath;
33  import org.hipparchus.util.MathUtils;
34  
35  /**
36   * <a href="http://en.wikipedia.org/wiki/Sigmoid_function">
37   *  Sigmoid</a> function.
38   * It is the inverse of the {@link Logit logit} function.
39   * A more flexible version, the generalised logistic, is implemented
40   * by the {@link Logistic} class.
41   *
42   */
43  public class Sigmoid implements UnivariateDifferentiableFunction {
44      /** Lower asymptote. */
45      private final double lo;
46      /** Higher asymptote. */
47      private final double hi;
48  
49      /**
50       * Usual sigmoid function, where the lower asymptote is 0 and the higher
51       * asymptote is 1.
52       */
53      public Sigmoid() {
54          this(0, 1);
55      }
56  
57      /**
58       * Sigmoid function.
59       *
60       * @param lo Lower asymptote.
61       * @param hi Higher asymptote.
62       */
63      public Sigmoid(double lo,
64                     double hi) {
65          this.lo = lo;
66          this.hi = hi;
67      }
68  
69      /** {@inheritDoc} */
70      @Override
71      public double value(double x) {
72          return value(x, lo, hi);
73      }
74  
75      /**
76       * Parametric function where the input array contains the parameters of
77       * the {@link Sigmoid#Sigmoid(double,double) sigmoid function}, ordered
78       * as follows:
79       * <ul>
80       *  <li>Lower asymptote</li>
81       *  <li>Higher asymptote</li>
82       * </ul>
83       */
84      public static class Parametric implements ParametricUnivariateFunction {
85  
86          /** Empty constructor.
87           * <p>
88           * This constructor is not strictly necessary, but it prevents spurious
89           * javadoc warnings with JDK 18 and later.
90           * </p>
91           * @since 3.0
92           */
93          public Parametric() { // NOPMD - unnecessary constructor added intentionally to make javadoc happy
94              // nothing to do
95          }
96  
97          /**
98           * Computes the value of the sigmoid at {@code x}.
99           *
100          * @param x Value for which the function must be computed.
101          * @param param Values of lower asymptote and higher asymptote.
102          * @return the value of the function.
103          * @throws NullArgumentException if {@code param} is {@code null}.
104          * @throws MathIllegalArgumentException if the size of {@code param} is
105          * not 2.
106          */
107         @Override
108         public double value(double x, double ... param)
109             throws MathIllegalArgumentException, NullArgumentException {
110             validateParameters(param);
111             return Sigmoid.value(x, param[0], param[1]);
112         }
113 
114         /**
115          * Computes the value of the gradient at {@code x}.
116          * The components of the gradient vector are the partial
117          * derivatives of the function with respect to each of the
118          * <em>parameters</em> (lower asymptote and higher asymptote).
119          *
120          * @param x Value at which the gradient must be computed.
121          * @param param Values for lower asymptote and higher asymptote.
122          * @return the gradient vector at {@code x}.
123          * @throws NullArgumentException if {@code param} is {@code null}.
124          * @throws MathIllegalArgumentException if the size of {@code param} is
125          * not 2.
126          */
127         @Override
128         public double[] gradient(double x, double ... param)
129             throws MathIllegalArgumentException, NullArgumentException {
130             validateParameters(param);
131 
132             final double invExp1 = 1 / (1 + FastMath.exp(-x));
133 
134             return new double[] { 1 - invExp1, invExp1 };
135         }
136 
137         /**
138          * Validates parameters to ensure they are appropriate for the evaluation of
139          * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
140          * methods.
141          *
142          * @param param Values for lower and higher asymptotes.
143          * @throws NullArgumentException if {@code param} is {@code null}.
144          * @throws MathIllegalArgumentException if the size of {@code param} is
145          * not 2.
146          */
147         private void validateParameters(double[] param)
148             throws MathIllegalArgumentException, NullArgumentException {
149             MathUtils.checkNotNull(param);
150             MathUtils.checkDimension(param.length, 2);
151         }
152     }
153 
154     /**
155      * @param x Value at which to compute the sigmoid.
156      * @param lo Lower asymptote.
157      * @param hi Higher asymptote.
158      * @return the value of the sigmoid function at {@code x}.
159      */
160     private static double value(double x,
161                                 double lo,
162                                 double hi) {
163         return lo + (hi - lo) / (1 + FastMath.exp(-x));
164     }
165 
166     /** {@inheritDoc}
167      */
168     @Override
169     public <T extends Derivative<T>> T value(T t)
170         throws MathIllegalArgumentException {
171 
172         double[] f = new double[t.getOrder() + 1];
173         final double exp = FastMath.exp(-t.getValue());
174         if (Double.isInfinite(exp)) {
175 
176             // special handling near lower boundary, to avoid NaN
177             f[0] = lo;
178             Arrays.fill(f, 1, f.length, 0.0);
179 
180         } else {
181 
182             // the nth order derivative of sigmoid has the form:
183             // dn(sigmoid(x)/dxn = P_n(exp(-x)) / (1+exp(-x))^(n+1)
184             // where P_n(t) is a degree n polynomial with normalized higher term
185             // P_0(t) = 1, P_1(t) = t, P_2(t) = t^2 - t, P_3(t) = t^3 - 4 t^2 + t...
186             // the general recurrence relation for P_n is:
187             // P_n(x) = n t P_(n-1)(t) - t (1 + t) P_(n-1)'(t)
188             final double[] p = new double[f.length];
189 
190             final double inv   = 1 / (1 + exp);
191             double coeff = hi - lo;
192             for (int n = 0; n < f.length; ++n) {
193 
194                 // update and evaluate polynomial P_n(t)
195                 double v = 0;
196                 p[n] = 1;
197                 for (int k = n; k >= 0; --k) {
198                     v = v * exp + p[k];
199                     if (k > 1) {
200                         p[k - 1] = (n - k + 2) * p[k - 2] - (k - 1) * p[k - 1];
201                     } else {
202                         p[0] = 0;
203                     }
204                 }
205 
206                 coeff *= inv;
207                 f[n]   = coeff * v;
208 
209             }
210 
211             // fix function value
212             f[0] += lo;
213 
214         }
215 
216         return t.compose(f);
217 
218     }
219 
220 }