Logistic.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      https://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. /*
  18.  * This is not the original file distributed by the Apache Software Foundation
  19.  * It has been modified by the Hipparchus project
  20.  */

  21. package org.hipparchus.analysis.function;

  22. import org.hipparchus.analysis.ParametricUnivariateFunction;
  23. import org.hipparchus.analysis.differentiation.Derivative;
  24. import org.hipparchus.analysis.differentiation.UnivariateDifferentiableFunction;
  25. import org.hipparchus.exception.LocalizedCoreFormats;
  26. import org.hipparchus.exception.MathIllegalArgumentException;
  27. import org.hipparchus.exception.NullArgumentException;
  28. import org.hipparchus.util.FastMath;
  29. import org.hipparchus.util.MathUtils;

  30. /**
  31.  * <a href="http://en.wikipedia.org/wiki/Generalised_logistic_function">
  32.  *  Generalised logistic</a> function.
  33.  *
  34.  */
  35. public class Logistic implements UnivariateDifferentiableFunction {
  36.     /** Lower asymptote. */
  37.     private final double a;
  38.     /** Upper asymptote. */
  39.     private final double k;
  40.     /** Growth rate. */
  41.     private final double b;
  42.     /** Parameter that affects near which asymptote maximum growth occurs. */
  43.     private final double oneOverN;
  44.     /** Parameter that affects the position of the curve along the ordinate axis. */
  45.     private final double q;
  46.     /** Abscissa of maximum growth. */
  47.     private final double m;

  48.     /** Simple constructor.
  49.      * @param k If {@code b > 0}, value of the function for x going towards +&infin;.
  50.      * If {@code b < 0}, value of the function for x going towards -&infin;.
  51.      * @param m Abscissa of maximum growth.
  52.      * @param b Growth rate.
  53.      * @param q Parameter that affects the position of the curve along the
  54.      * ordinate axis.
  55.      * @param a If {@code b > 0}, value of the function for x going towards -&infin;.
  56.      * If {@code b < 0}, value of the function for x going towards +&infin;.
  57.      * @param n Parameter that affects near which asymptote the maximum
  58.      * growth occurs.
  59.      * @throws MathIllegalArgumentException if {@code n <= 0}.
  60.      */
  61.     public Logistic(double k,
  62.                     double m,
  63.                     double b,
  64.                     double q,
  65.                     double a,
  66.                     double n)
  67.         throws MathIllegalArgumentException {
  68.         if (n <= 0) {
  69.             throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL_BOUND_EXCLUDED,
  70.                                                    n, 0);
  71.         }

  72.         this.k = k;
  73.         this.m = m;
  74.         this.b = b;
  75.         this.q = q;
  76.         this.a = a;
  77.         oneOverN = 1 / n;
  78.     }

  79.     /** {@inheritDoc} */
  80.     @Override
  81.     public double value(double x) {
  82.         return value(m - x, k, b, q, a, oneOverN);
  83.     }

  84.     /**
  85.      * Parametric function where the input array contains the parameters of
  86.      * the {@link Logistic#Logistic(double,double,double,double,double,double)
  87.      * logistic function}, ordered as follows:
  88.      * <ul>
  89.      *  <li>k</li>
  90.      *  <li>m</li>
  91.      *  <li>b</li>
  92.      *  <li>q</li>
  93.      *  <li>a</li>
  94.      *  <li>n</li>
  95.      * </ul>
  96.      */
  97.     public static class Parametric implements ParametricUnivariateFunction {

  98.         /** Empty constructor.
  99.          * <p>
  100.          * This constructor is not strictly necessary, but it prevents spurious
  101.          * javadoc warnings with JDK 18 and later.
  102.          * </p>
  103.          * @since 3.0
  104.          */
  105.         public Parametric() { // NOPMD - unnecessary constructor added intentionally to make javadoc happy
  106.             // nothing to do
  107.         }

  108.         /**
  109.          * Computes the value of the sigmoid at {@code x}.
  110.          *
  111.          * @param x Value for which the function must be computed.
  112.          * @param param Values for {@code k}, {@code m}, {@code b}, {@code q},
  113.          * {@code a} and  {@code n}.
  114.          * @return the value of the function.
  115.          * @throws NullArgumentException if {@code param} is {@code null}.
  116.          * @throws MathIllegalArgumentException if the size of {@code param} is
  117.          * not 6.
  118.          * @throws MathIllegalArgumentException if {@code param[5] <= 0}.
  119.          */
  120.         @Override
  121.         public double value(double x, double ... param)
  122.             throws MathIllegalArgumentException, NullArgumentException {
  123.             validateParameters(param);
  124.             return Logistic.value(param[1] - x, param[0],
  125.                                   param[2], param[3],
  126.                                   param[4], 1 / param[5]);
  127.         }

  128.         /**
  129.          * Computes the value of the gradient at {@code x}.
  130.          * The components of the gradient vector are the partial
  131.          * derivatives of the function with respect to each of the
  132.          * <em>parameters</em>.
  133.          *
  134.          * @param x Value at which the gradient must be computed.
  135.          * @param param Values for {@code k}, {@code m}, {@code b}, {@code q},
  136.          * {@code a} and  {@code n}.
  137.          * @return the gradient vector at {@code x}.
  138.          * @throws NullArgumentException if {@code param} is {@code null}.
  139.          * @throws MathIllegalArgumentException if the size of {@code param} is
  140.          * not 6.
  141.          * @throws MathIllegalArgumentException if {@code param[5] <= 0}.
  142.          */
  143.         @Override
  144.         public double[] gradient(double x, double ... param)
  145.             throws MathIllegalArgumentException, NullArgumentException {
  146.             validateParameters(param);

  147.             final double b = param[2];
  148.             final double q = param[3];

  149.             final double mMinusX = param[1] - x;
  150.             final double oneOverN = 1 / param[5];
  151.             final double exp = FastMath.exp(b * mMinusX);
  152.             final double qExp = q * exp;
  153.             final double qExp1 = qExp + 1;
  154.             final double factor1 = (param[0] - param[4]) * oneOverN / FastMath.pow(qExp1, oneOverN);
  155.             final double factor2 = -factor1 / qExp1;

  156.             // Components of the gradient.
  157.             final double gk = Logistic.value(mMinusX, 1, b, q, 0, oneOverN);
  158.             final double gm = factor2 * b * qExp;
  159.             final double gb = factor2 * mMinusX * qExp;
  160.             final double gq = factor2 * exp;
  161.             final double ga = Logistic.value(mMinusX, 0, b, q, 1, oneOverN);
  162.             final double gn = factor1 * FastMath.log(qExp1) * oneOverN;

  163.             return new double[] { gk, gm, gb, gq, ga, gn };
  164.         }

  165.         /**
  166.          * Validates parameters to ensure they are appropriate for the evaluation of
  167.          * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
  168.          * methods.
  169.          *
  170.          * @param param Values for {@code k}, {@code m}, {@code b}, {@code q},
  171.          * {@code a} and {@code n}.
  172.          * @throws NullArgumentException if {@code param} is {@code null}.
  173.          * @throws MathIllegalArgumentException if the size of {@code param} is
  174.          * not 6.
  175.          * @throws MathIllegalArgumentException if {@code param[5] <= 0}.
  176.          */
  177.         private void validateParameters(double[] param)
  178.             throws MathIllegalArgumentException, NullArgumentException {
  179.             MathUtils.checkNotNull(param);
  180.             MathUtils.checkDimension(param.length, 6);
  181.             if (param[5] <= 0) {
  182.                 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL_BOUND_EXCLUDED,
  183.                                                        param[5], 0);
  184.             }
  185.         }
  186.     }

  187.     /**
  188.      * @param mMinusX {@code m - x}.
  189.      * @param k {@code k}.
  190.      * @param b {@code b}.
  191.      * @param q {@code q}.
  192.      * @param a {@code a}.
  193.      * @param oneOverN {@code 1 / n}.
  194.      * @return the value of the function.
  195.      */
  196.     private static double value(double mMinusX,
  197.                                 double k,
  198.                                 double b,
  199.                                 double q,
  200.                                 double a,
  201.                                 double oneOverN) {
  202.         return a + (k - a) / FastMath.pow(1 + q * FastMath.exp(b * mMinusX), oneOverN);
  203.     }

  204.     /** {@inheritDoc}
  205.      */
  206.     @Override
  207.     public <T extends Derivative<T>> T value(T t) {
  208.         return t.negate().add(m).multiply(b).exp().multiply(q).add(1).pow(oneOverN).reciprocal().multiply(k - a).add(a);
  209.     }

  210. }