1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one or more 3 * contributor license agreements. See the NOTICE file distributed with 4 * this work for additional information regarding copyright ownership. 5 * The ASF licenses this file to You under the Apache License, Version 2.0 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * https://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 18 /* 19 * This is not the original file distributed by the Apache Software Foundation 20 * It has been modified by the Hipparchus project 21 */ 22 23 package org.hipparchus.analysis.function; 24 25 import org.hipparchus.analysis.ParametricUnivariateFunction; 26 import org.hipparchus.analysis.differentiation.Derivative; 27 import org.hipparchus.analysis.differentiation.UnivariateDifferentiableFunction; 28 import org.hipparchus.exception.MathIllegalArgumentException; 29 import org.hipparchus.exception.NullArgumentException; 30 import org.hipparchus.util.FastMath; 31 import org.hipparchus.util.MathUtils; 32 33 /** 34 * <a href="http://en.wikipedia.org/wiki/Logit"> 35 * Logit</a> function. 36 * It is the inverse of the {@link Sigmoid sigmoid} function. 37 * 38 */ 39 public class Logit implements UnivariateDifferentiableFunction { 40 /** Lower bound. */ 41 private final double lo; 42 /** Higher bound. */ 43 private final double hi; 44 45 /** 46 * Usual logit function, where the lower bound is 0 and the higher 47 * bound is 1. 48 */ 49 public Logit() { 50 this(0, 1); 51 } 52 53 /** 54 * Logit function. 55 * 56 * @param lo Lower bound of the function domain. 57 * @param hi Higher bound of the function domain. 58 */ 59 public Logit(double lo, 60 double hi) { 61 this.lo = lo; 62 this.hi = hi; 63 } 64 65 /** {@inheritDoc} */ 66 @Override 67 public double value(double x) 68 throws MathIllegalArgumentException { 69 return value(x, lo, hi); 70 } 71 72 /** 73 * Parametric function where the input array contains the parameters of 74 * the logit function, ordered as follows: 75 * <ul> 76 * <li>Lower bound</li> 77 * <li>Higher bound</li> 78 * </ul> 79 */ 80 public static class Parametric implements ParametricUnivariateFunction { 81 82 /** Empty constructor. 83 * <p> 84 * This constructor is not strictly necessary, but it prevents spurious 85 * javadoc warnings with JDK 18 and later. 86 * </p> 87 * @since 3.0 88 */ 89 public Parametric() { // NOPMD - unnecessary constructor added intentionally to make javadoc happy 90 // nothing to do 91 } 92 93 /** 94 * Computes the value of the logit at {@code x}. 95 * 96 * @param x Value for which the function must be computed. 97 * @param param Values of lower bound and higher bounds. 98 * @return the value of the function. 99 * @throws NullArgumentException if {@code param} is {@code null}. 100 * @throws MathIllegalArgumentException if the size of {@code param} is 101 * not 2. 102 */ 103 @Override 104 public double value(double x, double ... param) 105 throws MathIllegalArgumentException, NullArgumentException { 106 validateParameters(param); 107 return Logit.value(x, param[0], param[1]); 108 } 109 110 /** 111 * Computes the value of the gradient at {@code x}. 112 * The components of the gradient vector are the partial 113 * derivatives of the function with respect to each of the 114 * <em>parameters</em> (lower bound and higher bound). 115 * 116 * @param x Value at which the gradient must be computed. 117 * @param param Values for lower and higher bounds. 118 * @return the gradient vector at {@code x}. 119 * @throws NullArgumentException if {@code param} is {@code null}. 120 * @throws MathIllegalArgumentException if the size of {@code param} is 121 * not 2. 122 */ 123 @Override 124 public double[] gradient(double x, double ... param) 125 throws MathIllegalArgumentException, NullArgumentException { 126 validateParameters(param); 127 128 final double lo = param[0]; 129 final double hi = param[1]; 130 131 return new double[] { 1 / (lo - x), 1 / (hi - x) }; 132 } 133 134 /** 135 * Validates parameters to ensure they are appropriate for the evaluation of 136 * the {@link #value(double,double[])} and {@link #gradient(double,double[])} 137 * methods. 138 * 139 * @param param Values for lower and higher bounds. 140 * @throws NullArgumentException if {@code param} is {@code null}. 141 * @throws MathIllegalArgumentException if the size of {@code param} is 142 * not 2. 143 */ 144 private void validateParameters(double[] param) 145 throws MathIllegalArgumentException, NullArgumentException { 146 MathUtils.checkNotNull(param); 147 MathUtils.checkDimension(param.length, 2); 148 } 149 } 150 151 /** 152 * @param x Value at which to compute the logit. 153 * @param lo Lower bound. 154 * @param hi Higher bound. 155 * @return the value of the logit function at {@code x}. 156 * @throws MathIllegalArgumentException if {@code x < lo} or {@code x > hi}. 157 */ 158 private static double value(double x, 159 double lo, 160 double hi) 161 throws MathIllegalArgumentException { 162 MathUtils.checkRangeInclusive(x, lo, hi); 163 return FastMath.log((x - lo) / (hi - x)); 164 } 165 166 /** {@inheritDoc} 167 * @exception MathIllegalArgumentException if parameter is outside of function domain 168 */ 169 @Override 170 public <T extends Derivative<T>> T value(T t) 171 throws MathIllegalArgumentException { 172 final double x = t.getValue(); 173 MathUtils.checkRangeInclusive(x, lo, hi); 174 double[] f = new double[t.getOrder() + 1]; 175 176 // function value 177 f[0] = FastMath.log((x - lo) / (hi - x)); 178 179 if (Double.isInfinite(f[0])) { 180 181 if (f.length > 1) { 182 f[1] = Double.POSITIVE_INFINITY; 183 } 184 // fill the array with infinities 185 // (for x close to lo the signs will flip between -inf and +inf, 186 // for x close to hi the signs will always be +inf) 187 // this is probably overkill, since the call to compose at the end 188 // of the method will transform most infinities into NaN ... 189 for (int i = 2; i < f.length; ++i) { 190 f[i] = f[i - 2]; 191 } 192 193 } else { 194 195 // function derivatives 196 final double invL = 1.0 / (x - lo); 197 double xL = invL; 198 final double invH = 1.0 / (hi - x); 199 double xH = invH; 200 for (int i = 1; i < f.length; ++i) { 201 f[i] = xL + xH; 202 xL *= -i * invL; 203 xH *= i * invH; 204 } 205 } 206 207 return t.compose(f); 208 } 209 }