/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * This is not the original file distributed by the Apache Software Foundation
 * It has been modified by the Hipparchus project
 */

package org.hipparchus.analysis.function;

import java.util.Arrays;

import org.hipparchus.analysis.ParametricUnivariateFunction;
import org.hipparchus.analysis.differentiation.Derivative;
import org.hipparchus.analysis.differentiation.UnivariateDifferentiableFunction;
import org.hipparchus.exception.MathIllegalArgumentException;
import org.hipparchus.exception.NullArgumentException;
import org.hipparchus.util.FastMath;
import org.hipparchus.util.MathUtils;

/**
 * <a href="http://en.wikipedia.org/wiki/Sigmoid_function">
 *  Sigmoid</a> function, evaluated as
 * {@code lo + (hi - lo) / (1 + exp(-x))}.
 * It is the inverse of the {@link Logit logit} function.
 * A more flexible version, the generalised logistic, is implemented
 * by the {@link Logistic} class.
 *
 */
public class Sigmoid implements UnivariateDifferentiableFunction {
    /** Lower asymptote. */
    private final double lo;
    /** Higher asymptote. */
    private final double hi;

    /**
     * Builds the usual sigmoid function, whose lower asymptote is 0 and
     * whose higher asymptote is 1.
     */
    public Sigmoid() {
        this(0, 1);
    }

    /**
     * Builds a sigmoid function with the given asymptotes.
     *
     * @param lo Lower asymptote.
     * @param hi Higher asymptote.
     */
    public Sigmoid(double lo,
                   double hi) {
        this.lo = lo;
        this.hi = hi;
    }

    /** {@inheritDoc} */
    @Override
    public double value(double x) {
        return value(x, lo, hi);
    }

    /**
     * Parametric function where the input array contains the parameters of
     * the {@link Sigmoid#Sigmoid(double,double) sigmoid function}, ordered
     * as follows:
     * <ul>
     *  <li>Lower asymptote</li>
     *  <li>Higher asymptote</li>
     * </ul>
     */
    public static class Parametric implements ParametricUnivariateFunction {

        /** Empty constructor.
         * <p>
         * This constructor is not strictly necessary, but it prevents spurious
         * javadoc warnings with JDK 18 and later.
         * </p>
         * @since 3.0
         */
        public Parametric() { // NOPMD - unnecessary constructor added intentionally to make javadoc happy
            // nothing to do
        }

        /**
         * Computes the value of the sigmoid at {@code x}.
         *
         * @param x Value for which the function must be computed.
         * @param param Values of lower asymptote and higher asymptote,
         * in that order.
         * @return the value of the function.
         * @throws NullArgumentException if {@code param} is {@code null}.
         * @throws MathIllegalArgumentException if the size of {@code param} is
         * not 2.
         */
        @Override
        public double value(double x, double ... param)
            throws MathIllegalArgumentException, NullArgumentException {
            validateParameters(param);

            final double lowerAsymptote  = param[0];
            final double higherAsymptote = param[1];
            return Sigmoid.value(x, lowerAsymptote, higherAsymptote);
        }

        /**
         * Computes the value of the gradient at {@code x}.
         * The components of the gradient vector are the partial
         * derivatives of the function with respect to each of the
         * <em>parameters</em> (lower asymptote and higher asymptote).
         * <p>
         * The parameters are only validated: the returned components
         * depend on {@code x} alone.
         * </p>
         *
         * @param x Value at which the gradient must be computed.
         * @param param Values for lower asymptote and higher asymptote.
         * @return the gradient vector at {@code x}.
         * @throws NullArgumentException if {@code param} is {@code null}.
         * @throws MathIllegalArgumentException if the size of {@code param} is
         * not 2.
         */
        @Override
        public double[] gradient(double x, double ... param)
            throws MathIllegalArgumentException, NullArgumentException {
            validateParameters(param);

            // partial derivative with respect to the higher asymptote
            final double wrtHi = 1 / (1 + FastMath.exp(-x));
            // partial derivative with respect to the lower asymptote
            final double wrtLo = 1 - wrtHi;

            return new double[] { wrtLo, wrtHi };
        }

        /**
         * Validates parameters to ensure they are appropriate for the evaluation of
         * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
         * methods.
         *
         * @param param Values for lower and higher asymptotes.
         * @throws NullArgumentException if {@code param} is {@code null}.
         * @throws MathIllegalArgumentException if the size of {@code param} is
         * not 2.
         */
        private void validateParameters(double[] param)
            throws MathIllegalArgumentException, NullArgumentException {
            // the null check must come first, the dimension check reads param.length
            MathUtils.checkNotNull(param);
            MathUtils.checkDimension(param.length, 2);
        }
    }

    /**
     * Evaluates the sigmoid with explicit asymptotes.
     *
     * @param x Value at which to compute the sigmoid.
     * @param lo Lower asymptote.
     * @param hi Higher asymptote.
     * @return the value of the sigmoid function at {@code x}.
     */
    private static double value(double x,
                                double lo,
                                double hi) {
        final double range       = hi - lo;
        final double denominator = 1 + FastMath.exp(-x);
        return lo + range / denominator;
    }

    /** {@inheritDoc}
     */
    @Override
    public <T extends Derivative<T>> T value(T t)
        throws MathIllegalArgumentException {

        // one slot for the value plus one per derivation order
        final int size = t.getOrder() + 1;
        final double[] derivatives = new double[size];
        final double expMinusX = FastMath.exp(-t.getValue());

        if (Double.isInfinite(expMinusX)) {
            // exp(-x) overflowed: use the lower asymptote as value and
            // zero for all derivatives, to avoid NaN
            derivatives[0] = lo;
            Arrays.fill(derivatives, 1, size, 0.0);
            return t.compose(derivatives);
        }

        // the nth order derivative of sigmoid has the form:
        // d^n sigmoid(x) / dx^n = P_n(exp(-x)) / (1 + exp(-x))^(n+1)
        // where P_n(t) is a degree n polynomial with normalized higher term
        // P_0(t) = 1, P_1(t) = t, P_2(t) = t^2 - t, P_3(t) = t^3 - 4 t^2 + t...
        // the general recurrence relation for P_n is:
        // P_n(t) = n t P_(n-1)(t) - t (1 + t) P_(n-1)'(t)
        final double[] polynomial = new double[size];

        final double invDenominator = 1 / (1 + expMinusX);
        double scale = hi - lo;
        for (int n = 0; n < size; ++n) {

            // Horner evaluation of P_n at exp(-x), from the highest degree down;
            // each lower coefficient is upgraded in place from P_(n-1) to P_n
            // right before the evaluation consumes it
            double polynomialValue = 0;
            polynomial[n] = 1;
            for (int k = n; k >= 0; --k) {
                polynomialValue = polynomialValue * expMinusX + polynomial[k];
                if (k <= 1) {
                    // the constant term is reset to 0 once consumed
                    polynomial[0] = 0;
                } else {
                    polynomial[k - 1] = (n - k + 2) * polynomial[k - 2] - (k - 1) * polynomial[k - 1];
                }
            }

            // scale accumulates (hi - lo) / (1 + exp(-x))^(n+1)
            scale *= invDenominator;
            derivatives[n] = scale * polynomialValue;

        }

        // shift the function value by the lower asymptote
        derivatives[0] += lo;

        return t.compose(derivatives);

    }

}