1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * https://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 /*
19 * This is not the original file distributed by the Apache Software Foundation
20 * It has been modified by the Hipparchus project
21 */
22
23 package org.hipparchus.analysis.function;
24
25 import org.hipparchus.analysis.ParametricUnivariateFunction;
26 import org.hipparchus.analysis.differentiation.Derivative;
27 import org.hipparchus.analysis.differentiation.UnivariateDifferentiableFunction;
28 import org.hipparchus.exception.MathIllegalArgumentException;
29 import org.hipparchus.exception.NullArgumentException;
30 import org.hipparchus.util.FastMath;
31 import org.hipparchus.util.MathUtils;
32
33 /**
34 * <a href="http://en.wikipedia.org/wiki/Logit">
35 * Logit</a> function.
36 * It is the inverse of the {@link Sigmoid sigmoid} function.
37 *
38 */
39 public class Logit implements UnivariateDifferentiableFunction {
40 /** Lower bound. */
41 private final double lo;
42 /** Higher bound. */
43 private final double hi;
44
45 /**
46 * Usual logit function, where the lower bound is 0 and the higher
47 * bound is 1.
48 */
49 public Logit() {
50 this(0, 1);
51 }
52
53 /**
54 * Logit function.
55 *
56 * @param lo Lower bound of the function domain.
57 * @param hi Higher bound of the function domain.
58 */
59 public Logit(double lo,
60 double hi) {
61 this.lo = lo;
62 this.hi = hi;
63 }
64
65 /** {@inheritDoc} */
66 @Override
67 public double value(double x)
68 throws MathIllegalArgumentException {
69 return value(x, lo, hi);
70 }
71
72 /**
73 * Parametric function where the input array contains the parameters of
74 * the logit function, ordered as follows:
75 * <ul>
76 * <li>Lower bound</li>
77 * <li>Higher bound</li>
78 * </ul>
79 */
80 public static class Parametric implements ParametricUnivariateFunction {
81
82 /** Empty constructor.
83 * <p>
84 * This constructor is not strictly necessary, but it prevents spurious
85 * javadoc warnings with JDK 18 and later.
86 * </p>
87 * @since 3.0
88 */
89 public Parametric() { // NOPMD - unnecessary constructor added intentionally to make javadoc happy
90 // nothing to do
91 }
92
93 /**
94 * Computes the value of the logit at {@code x}.
95 *
96 * @param x Value for which the function must be computed.
97 * @param param Values of lower bound and higher bounds.
98 * @return the value of the function.
99 * @throws NullArgumentException if {@code param} is {@code null}.
100 * @throws MathIllegalArgumentException if the size of {@code param} is
101 * not 2.
102 */
103 @Override
104 public double value(double x, double ... param)
105 throws MathIllegalArgumentException, NullArgumentException {
106 validateParameters(param);
107 return Logit.value(x, param[0], param[1]);
108 }
109
110 /**
111 * Computes the value of the gradient at {@code x}.
112 * The components of the gradient vector are the partial
113 * derivatives of the function with respect to each of the
114 * <em>parameters</em> (lower bound and higher bound).
115 *
116 * @param x Value at which the gradient must be computed.
117 * @param param Values for lower and higher bounds.
118 * @return the gradient vector at {@code x}.
119 * @throws NullArgumentException if {@code param} is {@code null}.
120 * @throws MathIllegalArgumentException if the size of {@code param} is
121 * not 2.
122 */
123 @Override
124 public double[] gradient(double x, double ... param)
125 throws MathIllegalArgumentException, NullArgumentException {
126 validateParameters(param);
127
128 final double lo = param[0];
129 final double hi = param[1];
130
131 return new double[] { 1 / (lo - x), 1 / (hi - x) };
132 }
133
134 /**
135 * Validates parameters to ensure they are appropriate for the evaluation of
136 * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
137 * methods.
138 *
139 * @param param Values for lower and higher bounds.
140 * @throws NullArgumentException if {@code param} is {@code null}.
141 * @throws MathIllegalArgumentException if the size of {@code param} is
142 * not 2.
143 */
144 private void validateParameters(double[] param)
145 throws MathIllegalArgumentException, NullArgumentException {
146 MathUtils.checkNotNull(param);
147 MathUtils.checkDimension(param.length, 2);
148 }
149 }
150
151 /**
152 * @param x Value at which to compute the logit.
153 * @param lo Lower bound.
154 * @param hi Higher bound.
155 * @return the value of the logit function at {@code x}.
156 * @throws MathIllegalArgumentException if {@code x < lo} or {@code x > hi}.
157 */
158 private static double value(double x,
159 double lo,
160 double hi)
161 throws MathIllegalArgumentException {
162 MathUtils.checkRangeInclusive(x, lo, hi);
163 return FastMath.log((x - lo) / (hi - x));
164 }
165
166 /** {@inheritDoc}
167 * @exception MathIllegalArgumentException if parameter is outside of function domain
168 */
169 @Override
170 public <T extends Derivative<T>> T value(T t)
171 throws MathIllegalArgumentException {
172 final double x = t.getValue();
173 MathUtils.checkRangeInclusive(x, lo, hi);
174 double[] f = new double[t.getOrder() + 1];
175
176 // function value
177 f[0] = FastMath.log((x - lo) / (hi - x));
178
179 if (Double.isInfinite(f[0])) {
180
181 if (f.length > 1) {
182 f[1] = Double.POSITIVE_INFINITY;
183 }
184 // fill the array with infinities
185 // (for x close to lo the signs will flip between -inf and +inf,
186 // for x close to hi the signs will always be +inf)
187 // this is probably overkill, since the call to compose at the end
188 // of the method will transform most infinities into NaN ...
189 for (int i = 2; i < f.length; ++i) {
190 f[i] = f[i - 2];
191 }
192
193 } else {
194
195 // function derivatives
196 final double invL = 1.0 / (x - lo);
197 double xL = invL;
198 final double invH = 1.0 / (hi - x);
199 double xH = invH;
200 for (int i = 1; i < f.length; ++i) {
201 f[i] = xL + xH;
202 xL *= -i * invL;
203 xH *= i * invH;
204 }
205 }
206
207 return t.compose(f);
208 }
209 }