1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one or more 3 * contributor license agreements. See the NOTICE file distributed with 4 * this work for additional information regarding copyright ownership. 5 * The ASF licenses this file to You under the Apache License, Version 2.0 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * https://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 18 /* 19 * This is not the original file distributed by the Apache Software Foundation 20 * It has been modified by the Hipparchus project 21 */ 22 23 package org.hipparchus.analysis.function; 24 25 import org.hipparchus.analysis.ParametricUnivariateFunction; 26 import org.hipparchus.analysis.differentiation.Derivative; 27 import org.hipparchus.analysis.differentiation.UnivariateDifferentiableFunction; 28 import org.hipparchus.exception.LocalizedCoreFormats; 29 import org.hipparchus.exception.MathIllegalArgumentException; 30 import org.hipparchus.exception.NullArgumentException; 31 import org.hipparchus.util.FastMath; 32 import org.hipparchus.util.MathUtils; 33 34 /** 35 * <a href="http://en.wikipedia.org/wiki/Generalised_logistic_function"> 36 * Generalised logistic</a> function. 37 * 38 */ 39 public class Logistic implements UnivariateDifferentiableFunction { 40 /** Lower asymptote. */ 41 private final double a; 42 /** Upper asymptote. */ 43 private final double k; 44 /** Growth rate. */ 45 private final double b; 46 /** Parameter that affects near which asymptote maximum growth occurs. */ 47 private final double oneOverN; 48 /** Parameter that affects the position of the curve along the ordinate axis. */ 49 private final double q; 50 /** Abscissa of maximum growth. */ 51 private final double m; 52 53 /** Simple constructor. 54 * @param k If {@code b > 0}, value of the function for x going towards +∞. 55 * If {@code b < 0}, value of the function for x going towards -∞. 56 * @param m Abscissa of maximum growth. 57 * @param b Growth rate. 58 * @param q Parameter that affects the position of the curve along the 59 * ordinate axis. 60 * @param a If {@code b > 0}, value of the function for x going towards -∞. 61 * If {@code b < 0}, value of the function for x going towards +∞. 62 * @param n Parameter that affects near which asymptote the maximum 63 * growth occurs. 64 * @throws MathIllegalArgumentException if {@code n <= 0}. 65 */ 66 public Logistic(double k, 67 double m, 68 double b, 69 double q, 70 double a, 71 double n) 72 throws MathIllegalArgumentException { 73 if (n <= 0) { 74 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL_BOUND_EXCLUDED, 75 n, 0); 76 } 77 78 this.k = k; 79 this.m = m; 80 this.b = b; 81 this.q = q; 82 this.a = a; 83 oneOverN = 1 / n; 84 } 85 86 /** {@inheritDoc} */ 87 @Override 88 public double value(double x) { 89 return value(m - x, k, b, q, a, oneOverN); 90 } 91 92 /** 93 * Parametric function where the input array contains the parameters of 94 * the {@link Logistic#Logistic(double,double,double,double,double,double) 95 * logistic function}, ordered as follows: 96 * <ul> 97 * <li>k</li> 98 * <li>m</li> 99 * <li>b</li> 100 * <li>q</li> 101 * <li>a</li> 102 * <li>n</li> 103 * </ul> 104 */ 105 public static class Parametric implements ParametricUnivariateFunction { 106 107 /** Empty constructor. 108 * <p> 109 * This constructor is not strictly necessary, but it prevents spurious 110 * javadoc warnings with JDK 18 and later. 111 * </p> 112 * @since 3.0 113 */ 114 public Parametric() { // NOPMD - unnecessary constructor added intentionally to make javadoc happy 115 // nothing to do 116 } 117 118 /** 119 * Computes the value of the sigmoid at {@code x}. 120 * 121 * @param x Value for which the function must be computed. 122 * @param param Values for {@code k}, {@code m}, {@code b}, {@code q}, 123 * {@code a} and {@code n}. 124 * @return the value of the function. 125 * @throws NullArgumentException if {@code param} is {@code null}. 126 * @throws MathIllegalArgumentException if the size of {@code param} is 127 * not 6. 128 * @throws MathIllegalArgumentException if {@code param[5] <= 0}. 129 */ 130 @Override 131 public double value(double x, double ... param) 132 throws MathIllegalArgumentException, NullArgumentException { 133 validateParameters(param); 134 return Logistic.value(param[1] - x, param[0], 135 param[2], param[3], 136 param[4], 1 / param[5]); 137 } 138 139 /** 140 * Computes the value of the gradient at {@code x}. 141 * The components of the gradient vector are the partial 142 * derivatives of the function with respect to each of the 143 * <em>parameters</em>. 144 * 145 * @param x Value at which the gradient must be computed. 146 * @param param Values for {@code k}, {@code m}, {@code b}, {@code q}, 147 * {@code a} and {@code n}. 148 * @return the gradient vector at {@code x}. 149 * @throws NullArgumentException if {@code param} is {@code null}. 150 * @throws MathIllegalArgumentException if the size of {@code param} is 151 * not 6. 152 * @throws MathIllegalArgumentException if {@code param[5] <= 0}. 153 */ 154 @Override 155 public double[] gradient(double x, double ... param) 156 throws MathIllegalArgumentException, NullArgumentException { 157 validateParameters(param); 158 159 final double b = param[2]; 160 final double q = param[3]; 161 162 final double mMinusX = param[1] - x; 163 final double oneOverN = 1 / param[5]; 164 final double exp = FastMath.exp(b * mMinusX); 165 final double qExp = q * exp; 166 final double qExp1 = qExp + 1; 167 final double factor1 = (param[0] - param[4]) * oneOverN / FastMath.pow(qExp1, oneOverN); 168 final double factor2 = -factor1 / qExp1; 169 170 // Components of the gradient. 171 final double gk = Logistic.value(mMinusX, 1, b, q, 0, oneOverN); 172 final double gm = factor2 * b * qExp; 173 final double gb = factor2 * mMinusX * qExp; 174 final double gq = factor2 * exp; 175 final double ga = Logistic.value(mMinusX, 0, b, q, 1, oneOverN); 176 final double gn = factor1 * FastMath.log(qExp1) * oneOverN; 177 178 return new double[] { gk, gm, gb, gq, ga, gn }; 179 } 180 181 /** 182 * Validates parameters to ensure they are appropriate for the evaluation of 183 * the {@link #value(double,double[])} and {@link #gradient(double,double[])} 184 * methods. 185 * 186 * @param param Values for {@code k}, {@code m}, {@code b}, {@code q}, 187 * {@code a} and {@code n}. 188 * @throws NullArgumentException if {@code param} is {@code null}. 189 * @throws MathIllegalArgumentException if the size of {@code param} is 190 * not 6. 191 * @throws MathIllegalArgumentException if {@code param[5] <= 0}. 192 */ 193 private void validateParameters(double[] param) 194 throws MathIllegalArgumentException, NullArgumentException { 195 MathUtils.checkNotNull(param); 196 MathUtils.checkDimension(param.length, 6); 197 if (param[5] <= 0) { 198 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL_BOUND_EXCLUDED, 199 param[5], 0); 200 } 201 } 202 } 203 204 /** 205 * @param mMinusX {@code m - x}. 206 * @param k {@code k}. 207 * @param b {@code b}. 208 * @param q {@code q}. 209 * @param a {@code a}. 210 * @param oneOverN {@code 1 / n}. 211 * @return the value of the function. 212 */ 213 private static double value(double mMinusX, 214 double k, 215 double b, 216 double q, 217 double a, 218 double oneOverN) { 219 return a + (k - a) / FastMath.pow(1 + q * FastMath.exp(b * mMinusX), oneOverN); 220 } 221 222 /** {@inheritDoc} 223 */ 224 @Override 225 public <T extends Derivative<T>> T value(T t) { 226 return t.negate().add(m).multiply(b).exp().multiply(q).add(1).pow(oneOverN).reciprocal().multiply(k - a).add(a); 227 } 228 229 }