1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * https://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 /*
19 * This is not the original file distributed by the Apache Software Foundation
20 * It has been modified by the Hipparchus project
21 */
22
23 package org.hipparchus.analysis.function;
24
25 import org.hipparchus.analysis.ParametricUnivariateFunction;
26 import org.hipparchus.analysis.differentiation.Derivative;
27 import org.hipparchus.analysis.differentiation.UnivariateDifferentiableFunction;
28 import org.hipparchus.exception.LocalizedCoreFormats;
29 import org.hipparchus.exception.MathIllegalArgumentException;
30 import org.hipparchus.exception.NullArgumentException;
31 import org.hipparchus.util.FastMath;
32 import org.hipparchus.util.MathUtils;
33
34 /**
35 * <a href="http://en.wikipedia.org/wiki/Generalised_logistic_function">
36 * Generalised logistic</a> function.
37 *
38 */
39 public class Logistic implements UnivariateDifferentiableFunction {
40 /** Lower asymptote. */
41 private final double a;
42 /** Upper asymptote. */
43 private final double k;
44 /** Growth rate. */
45 private final double b;
46 /** Parameter that affects near which asymptote maximum growth occurs. */
47 private final double oneOverN;
48 /** Parameter that affects the position of the curve along the ordinate axis. */
49 private final double q;
50 /** Abscissa of maximum growth. */
51 private final double m;
52
53 /** Simple constructor.
54 * @param k If {@code b > 0}, value of the function for x going towards +∞.
55 * If {@code b < 0}, value of the function for x going towards -∞.
56 * @param m Abscissa of maximum growth.
57 * @param b Growth rate.
58 * @param q Parameter that affects the position of the curve along the
59 * ordinate axis.
60 * @param a If {@code b > 0}, value of the function for x going towards -∞.
61 * If {@code b < 0}, value of the function for x going towards +∞.
62 * @param n Parameter that affects near which asymptote the maximum
63 * growth occurs.
64 * @throws MathIllegalArgumentException if {@code n <= 0}.
65 */
66 public Logistic(double k,
67 double m,
68 double b,
69 double q,
70 double a,
71 double n)
72 throws MathIllegalArgumentException {
73 if (n <= 0) {
74 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL_BOUND_EXCLUDED,
75 n, 0);
76 }
77
78 this.k = k;
79 this.m = m;
80 this.b = b;
81 this.q = q;
82 this.a = a;
83 oneOverN = 1 / n;
84 }
85
86 /** {@inheritDoc} */
87 @Override
88 public double value(double x) {
89 return value(m - x, k, b, q, a, oneOverN);
90 }
91
92 /**
93 * Parametric function where the input array contains the parameters of
94 * the {@link Logistic#Logistic(double,double,double,double,double,double)
95 * logistic function}, ordered as follows:
96 * <ul>
97 * <li>k</li>
98 * <li>m</li>
99 * <li>b</li>
100 * <li>q</li>
101 * <li>a</li>
102 * <li>n</li>
103 * </ul>
104 */
105 public static class Parametric implements ParametricUnivariateFunction {
106
107 /** Empty constructor.
108 * <p>
109 * This constructor is not strictly necessary, but it prevents spurious
110 * javadoc warnings with JDK 18 and later.
111 * </p>
112 * @since 3.0
113 */
114 public Parametric() { // NOPMD - unnecessary constructor added intentionally to make javadoc happy
115 // nothing to do
116 }
117
118 /**
119 * Computes the value of the sigmoid at {@code x}.
120 *
121 * @param x Value for which the function must be computed.
122 * @param param Values for {@code k}, {@code m}, {@code b}, {@code q},
123 * {@code a} and {@code n}.
124 * @return the value of the function.
125 * @throws NullArgumentException if {@code param} is {@code null}.
126 * @throws MathIllegalArgumentException if the size of {@code param} is
127 * not 6.
128 * @throws MathIllegalArgumentException if {@code param[5] <= 0}.
129 */
130 @Override
131 public double value(double x, double ... param)
132 throws MathIllegalArgumentException, NullArgumentException {
133 validateParameters(param);
134 return Logistic.value(param[1] - x, param[0],
135 param[2], param[3],
136 param[4], 1 / param[5]);
137 }
138
139 /**
140 * Computes the value of the gradient at {@code x}.
141 * The components of the gradient vector are the partial
142 * derivatives of the function with respect to each of the
143 * <em>parameters</em>.
144 *
145 * @param x Value at which the gradient must be computed.
146 * @param param Values for {@code k}, {@code m}, {@code b}, {@code q},
147 * {@code a} and {@code n}.
148 * @return the gradient vector at {@code x}.
149 * @throws NullArgumentException if {@code param} is {@code null}.
150 * @throws MathIllegalArgumentException if the size of {@code param} is
151 * not 6.
152 * @throws MathIllegalArgumentException if {@code param[5] <= 0}.
153 */
154 @Override
155 public double[] gradient(double x, double ... param)
156 throws MathIllegalArgumentException, NullArgumentException {
157 validateParameters(param);
158
159 final double b = param[2];
160 final double q = param[3];
161
162 final double mMinusX = param[1] - x;
163 final double oneOverN = 1 / param[5];
164 final double exp = FastMath.exp(b * mMinusX);
165 final double qExp = q * exp;
166 final double qExp1 = qExp + 1;
167 final double factor1 = (param[0] - param[4]) * oneOverN / FastMath.pow(qExp1, oneOverN);
168 final double factor2 = -factor1 / qExp1;
169
170 // Components of the gradient.
171 final double gk = Logistic.value(mMinusX, 1, b, q, 0, oneOverN);
172 final double gm = factor2 * b * qExp;
173 final double gb = factor2 * mMinusX * qExp;
174 final double gq = factor2 * exp;
175 final double ga = Logistic.value(mMinusX, 0, b, q, 1, oneOverN);
176 final double gn = factor1 * FastMath.log(qExp1) * oneOverN;
177
178 return new double[] { gk, gm, gb, gq, ga, gn };
179 }
180
181 /**
182 * Validates parameters to ensure they are appropriate for the evaluation of
183 * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
184 * methods.
185 *
186 * @param param Values for {@code k}, {@code m}, {@code b}, {@code q},
187 * {@code a} and {@code n}.
188 * @throws NullArgumentException if {@code param} is {@code null}.
189 * @throws MathIllegalArgumentException if the size of {@code param} is
190 * not 6.
191 * @throws MathIllegalArgumentException if {@code param[5] <= 0}.
192 */
193 private void validateParameters(double[] param)
194 throws MathIllegalArgumentException, NullArgumentException {
195 MathUtils.checkNotNull(param);
196 MathUtils.checkDimension(param.length, 6);
197 if (param[5] <= 0) {
198 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL_BOUND_EXCLUDED,
199 param[5], 0);
200 }
201 }
202 }
203
204 /**
205 * @param mMinusX {@code m - x}.
206 * @param k {@code k}.
207 * @param b {@code b}.
208 * @param q {@code q}.
209 * @param a {@code a}.
210 * @param oneOverN {@code 1 / n}.
211 * @return the value of the function.
212 */
213 private static double value(double mMinusX,
214 double k,
215 double b,
216 double q,
217 double a,
218 double oneOverN) {
219 return a + (k - a) / FastMath.pow(1 + q * FastMath.exp(b * mMinusX), oneOverN);
220 }
221
222 /** {@inheritDoc}
223 */
224 @Override
225 public <T extends Derivative<T>> T value(T t) {
226 return t.negate().add(m).multiply(b).exp().multiply(q).add(1).pow(oneOverN).reciprocal().multiply(k - a).add(a);
227 }
228
229 }