Logit.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      https://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. /*
  18.  * This is not the original file distributed by the Apache Software Foundation
  19.  * It has been modified by the Hipparchus project
  20.  */

  21. package org.hipparchus.analysis.function;

  22. import org.hipparchus.analysis.ParametricUnivariateFunction;
  23. import org.hipparchus.analysis.differentiation.Derivative;
  24. import org.hipparchus.analysis.differentiation.UnivariateDifferentiableFunction;
  25. import org.hipparchus.exception.MathIllegalArgumentException;
  26. import org.hipparchus.exception.NullArgumentException;
  27. import org.hipparchus.util.FastMath;
  28. import org.hipparchus.util.MathUtils;

  29. /**
  30.  * <a href="http://en.wikipedia.org/wiki/Logit">
  31.  *  Logit</a> function.
  32.  * It is the inverse of the {@link Sigmoid sigmoid} function.
  33.  *
  34.  */
  35. public class Logit implements UnivariateDifferentiableFunction {
  36.     /** Lower bound. */
  37.     private final double lo;
  38.     /** Higher bound. */
  39.     private final double hi;

  40.     /**
  41.      * Usual logit function, where the lower bound is 0 and the higher
  42.      * bound is 1.
  43.      */
  44.     public Logit() {
  45.         this(0, 1);
  46.     }

  47.     /**
  48.      * Logit function.
  49.      *
  50.      * @param lo Lower bound of the function domain.
  51.      * @param hi Higher bound of the function domain.
  52.      */
  53.     public Logit(double lo,
  54.                  double hi) {
  55.         this.lo = lo;
  56.         this.hi = hi;
  57.     }

  58.     /** {@inheritDoc} */
  59.     @Override
  60.     public double value(double x)
  61.         throws MathIllegalArgumentException {
  62.         return value(x, lo, hi);
  63.     }

  64.     /**
  65.      * Parametric function where the input array contains the parameters of
  66.      * the logit function, ordered as follows:
  67.      * <ul>
  68.      *  <li>Lower bound</li>
  69.      *  <li>Higher bound</li>
  70.      * </ul>
  71.      */
  72.     public static class Parametric implements ParametricUnivariateFunction {

  73.         /** Empty constructor.
  74.          * <p>
  75.          * This constructor is not strictly necessary, but it prevents spurious
  76.          * javadoc warnings with JDK 18 and later.
  77.          * </p>
  78.          * @since 3.0
  79.          */
  80.         public Parametric() { // NOPMD - unnecessary constructor added intentionally to make javadoc happy
  81.             // nothing to do
  82.         }

  83.         /**
  84.          * Computes the value of the logit at {@code x}.
  85.          *
  86.          * @param x Value for which the function must be computed.
  87.          * @param param Values of lower bound and higher bounds.
  88.          * @return the value of the function.
  89.          * @throws NullArgumentException if {@code param} is {@code null}.
  90.          * @throws MathIllegalArgumentException if the size of {@code param} is
  91.          * not 2.
  92.          */
  93.         @Override
  94.         public double value(double x, double ... param)
  95.             throws MathIllegalArgumentException, NullArgumentException {
  96.             validateParameters(param);
  97.             return Logit.value(x, param[0], param[1]);
  98.         }

  99.         /**
  100.          * Computes the value of the gradient at {@code x}.
  101.          * The components of the gradient vector are the partial
  102.          * derivatives of the function with respect to each of the
  103.          * <em>parameters</em> (lower bound and higher bound).
  104.          *
  105.          * @param x Value at which the gradient must be computed.
  106.          * @param param Values for lower and higher bounds.
  107.          * @return the gradient vector at {@code x}.
  108.          * @throws NullArgumentException if {@code param} is {@code null}.
  109.          * @throws MathIllegalArgumentException if the size of {@code param} is
  110.          * not 2.
  111.          */
  112.         @Override
  113.         public double[] gradient(double x, double ... param)
  114.             throws MathIllegalArgumentException, NullArgumentException {
  115.             validateParameters(param);

  116.             final double lo = param[0];
  117.             final double hi = param[1];

  118.             return new double[] { 1 / (lo - x), 1 / (hi - x) };
  119.         }

  120.         /**
  121.          * Validates parameters to ensure they are appropriate for the evaluation of
  122.          * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
  123.          * methods.
  124.          *
  125.          * @param param Values for lower and higher bounds.
  126.          * @throws NullArgumentException if {@code param} is {@code null}.
  127.          * @throws MathIllegalArgumentException if the size of {@code param} is
  128.          * not 2.
  129.          */
  130.         private void validateParameters(double[] param)
  131.             throws MathIllegalArgumentException, NullArgumentException {
  132.             MathUtils.checkNotNull(param);
  133.             MathUtils.checkDimension(param.length, 2);
  134.         }
  135.     }

  136.     /**
  137.      * @param x Value at which to compute the logit.
  138.      * @param lo Lower bound.
  139.      * @param hi Higher bound.
  140.      * @return the value of the logit function at {@code x}.
  141.      * @throws MathIllegalArgumentException if {@code x < lo} or {@code x > hi}.
  142.      */
  143.     private static double value(double x,
  144.                                 double lo,
  145.                                 double hi)
  146.         throws MathIllegalArgumentException {
  147.         MathUtils.checkRangeInclusive(x, lo, hi);
  148.         return FastMath.log((x - lo) / (hi - x));
  149.     }

  150.     /** {@inheritDoc}
  151.      * @exception MathIllegalArgumentException if parameter is outside of function domain
  152.      */
  153.     @Override
  154.     public <T extends Derivative<T>> T value(T t)
  155.         throws MathIllegalArgumentException {
  156.         final double x = t.getValue();
  157.         MathUtils.checkRangeInclusive(x, lo, hi);
  158.         double[] f = new double[t.getOrder() + 1];

  159.         // function value
  160.         f[0] = FastMath.log((x - lo) / (hi - x));

  161.         if (Double.isInfinite(f[0])) {

  162.             if (f.length > 1) {
  163.                 f[1] = Double.POSITIVE_INFINITY;
  164.             }
  165.             // fill the array with infinities
  166.             // (for x close to lo the signs will flip between -inf and +inf,
  167.             //  for x close to hi the signs will always be +inf)
  168.             // this is probably overkill, since the call to compose at the end
  169.             // of the method will transform most infinities into NaN ...
  170.             for (int i = 2; i < f.length; ++i) {
  171.                 f[i] = f[i - 2];
  172.             }

  173.         } else {

  174.             // function derivatives
  175.             final double invL = 1.0 / (x - lo);
  176.             double xL = invL;
  177.             final double invH = 1.0 / (hi - x);
  178.             double xH = invH;
  179.             for (int i = 1; i < f.length; ++i) {
  180.                 f[i] = xL + xH;
  181.                 xL  *= -i * invL;
  182.                 xH  *=  i * invH;
  183.             }
  184.         }

  185.         return t.compose(f);
  186.     }
  187. }