Gaussian.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      https://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. /*
  18.  * This is not the original file distributed by the Apache Software Foundation
  19.  * It has been modified by the Hipparchus project
  20.  */

  21. package org.hipparchus.analysis.function;

  22. import java.util.Arrays;

  23. import org.hipparchus.analysis.ParametricUnivariateFunction;
  24. import org.hipparchus.analysis.differentiation.Derivative;
  25. import org.hipparchus.analysis.differentiation.UnivariateDifferentiableFunction;
  26. import org.hipparchus.exception.LocalizedCoreFormats;
  27. import org.hipparchus.exception.MathIllegalArgumentException;
  28. import org.hipparchus.exception.NullArgumentException;
  29. import org.hipparchus.util.FastMath;
  30. import org.hipparchus.util.MathUtils;
  31. import org.hipparchus.util.Precision;

  32. /**
  33.  * <a href="http://en.wikipedia.org/wiki/Gaussian_function">
  34.  *  Gaussian</a> function.
  35.  *
  36.  */
  37. public class Gaussian implements UnivariateDifferentiableFunction {
  38.     /** Mean. */
  39.     private final double mean;
  40.     /** Inverse of the standard deviation. */
  41.     private final double is;
  42.     /** Inverse of twice the square of the standard deviation. */
  43.     private final double i2s2;
  44.     /** Normalization factor. */
  45.     private final double norm;

  46.     /**
  47.      * Gaussian with given normalization factor, mean and standard deviation.
  48.      *
  49.      * @param norm Normalization factor.
  50.      * @param mean Mean.
  51.      * @param sigma Standard deviation.
  52.      * @throws MathIllegalArgumentException if {@code sigma <= 0}.
  53.      */
  54.     public Gaussian(double norm,
  55.                     double mean,
  56.                     double sigma)
  57.         throws MathIllegalArgumentException {
  58.         if (sigma <= 0) {
  59.             throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL_BOUND_EXCLUDED,
  60.                                                    sigma, 0);
  61.         }

  62.         this.norm = norm;
  63.         this.mean = mean;
  64.         this.is   = 1 / sigma;
  65.         this.i2s2 = 0.5 * is * is;
  66.     }

  67.     /**
  68.      * Normalized gaussian with given mean and standard deviation.
  69.      *
  70.      * @param mean Mean.
  71.      * @param sigma Standard deviation.
  72.      * @throws MathIllegalArgumentException if {@code sigma <= 0}.
  73.      */
  74.     public Gaussian(double mean,
  75.                     double sigma)
  76.         throws MathIllegalArgumentException {
  77.         this(1 / (sigma * FastMath.sqrt(2 * Math.PI)), mean, sigma);
  78.     }

  79.     /**
  80.      * Normalized gaussian with zero mean and unit standard deviation.
  81.      */
  82.     public Gaussian() {
  83.         this(0, 1);
  84.     }

  85.     /** {@inheritDoc} */
  86.     @Override
  87.     public double value(double x) {
  88.         return value(x - mean, norm, i2s2);
  89.     }

  90.     /**
  91.      * Parametric function where the input array contains the parameters of
  92.      * the Gaussian, ordered as follows:
  93.      * <ul>
  94.      *  <li>Norm</li>
  95.      *  <li>Mean</li>
  96.      *  <li>Standard deviation</li>
  97.      * </ul>
  98.      */
  99.     public static class Parametric implements ParametricUnivariateFunction {

  100.         /** Empty constructor.
  101.          * <p>
  102.          * This constructor is not strictly necessary, but it prevents spurious
  103.          * javadoc warnings with JDK 18 and later.
  104.          * </p>
  105.          * @since 3.0
  106.          */
  107.         public Parametric() { // NOPMD - unnecessary constructor added intentionally to make javadoc happy
  108.             // nothing to do
  109.         }

  110.         /**
  111.          * Computes the value of the Gaussian at {@code x}.
  112.          *
  113.          * @param x Value for which the function must be computed.
  114.          * @param param Values of norm, mean and standard deviation.
  115.          * @return the value of the function.
  116.          * @throws NullArgumentException if {@code param} is {@code null}.
  117.          * @throws MathIllegalArgumentException if the size of {@code param} is
  118.          * not 3.
  119.          * @throws MathIllegalArgumentException if {@code param[2]} is negative.
  120.          */
  121.         @Override
  122.         public double value(double x, double ... param)
  123.             throws MathIllegalArgumentException, NullArgumentException {
  124.             validateParameters(param);

  125.             final double diff = x - param[1];
  126.             final double i2s2 = 1 / (2 * param[2] * param[2]);
  127.             return Gaussian.value(diff, param[0], i2s2);
  128.         }

  129.         /**
  130.          * Computes the value of the gradient at {@code x}.
  131.          * The components of the gradient vector are the partial
  132.          * derivatives of the function with respect to each of the
  133.          * <em>parameters</em> (norm, mean and standard deviation).
  134.          *
  135.          * @param x Value at which the gradient must be computed.
  136.          * @param param Values of norm, mean and standard deviation.
  137.          * @return the gradient vector at {@code x}.
  138.          * @throws NullArgumentException if {@code param} is {@code null}.
  139.          * @throws MathIllegalArgumentException if the size of {@code param} is
  140.          * not 3.
  141.          * @throws MathIllegalArgumentException if {@code param[2]} is negative.
  142.          */
  143.         @Override
  144.         public double[] gradient(double x, double ... param)
  145.             throws MathIllegalArgumentException, NullArgumentException {
  146.             validateParameters(param);

  147.             final double norm = param[0];
  148.             final double diff = x - param[1];
  149.             final double sigma = param[2];
  150.             final double i2s2 = 1 / (2 * sigma * sigma);

  151.             final double n = Gaussian.value(diff, 1, i2s2);
  152.             final double m = norm * n * 2 * i2s2 * diff;
  153.             final double s = m * diff / sigma;

  154.             return new double[] { n, m, s };
  155.         }

  156.         /**
  157.          * Validates parameters to ensure they are appropriate for the evaluation of
  158.          * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
  159.          * methods.
  160.          *
  161.          * @param param Values of norm, mean and standard deviation.
  162.          * @throws NullArgumentException if {@code param} is {@code null}.
  163.          * @throws MathIllegalArgumentException if the size of {@code param} is
  164.          * not 3.
  165.          * @throws MathIllegalArgumentException if {@code param[2]} is negative.
  166.          */
  167.         private void validateParameters(double[] param)
  168.             throws MathIllegalArgumentException, NullArgumentException {
  169.             if (param == null) {
  170.                 throw new NullArgumentException();
  171.             }
  172.             MathUtils.checkDimension(param.length, 3);
  173.             if (param[2] <= 0) {
  174.                 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL_BOUND_EXCLUDED,
  175.                                                        param[2], 0);
  176.             }
  177.         }
  178.     }

  179.     /**
  180.      * @param xMinusMean {@code x - mean}.
  181.      * @param norm Normalization factor.
  182.      * @param i2s2 Inverse of twice the square of the standard deviation.
  183.      * @return the value of the Gaussian at {@code x}.
  184.      */
  185.     private static double value(double xMinusMean,
  186.                                 double norm,
  187.                                 double i2s2) {
  188.         return norm * FastMath.exp(-xMinusMean * xMinusMean * i2s2);
  189.     }

  190.     /** {@inheritDoc}
  191.      */
  192.     @Override
  193.     public <T extends Derivative<T>> T value(T t)
  194.         throws MathIllegalArgumentException {

  195.         final double u = is * (t.getValue() - mean);
  196.         double[] f = new double[t.getOrder() + 1];

  197.         // the nth order derivative of the Gaussian has the form:
  198.         // dn(g(x)/dxn = (norm / s^n) P_n(u) exp(-u^2/2) with u=(x-m)/s
  199.         // where P_n(u) is a degree n polynomial with same parity as n
  200.         // P_0(u) = 1, P_1(u) = -u, P_2(u) = u^2 - 1, P_3(u) = -u^3 + 3 u...
  201.         // the general recurrence relation for P_n is:
  202.         // P_n(u) = P_(n-1)'(u) - u P_(n-1)(u)
  203.         // as per polynomial parity, we can store coefficients of both P_(n-1) and P_n in the same array
  204.         final double[] p = new double[f.length];
  205.         p[0] = 1;
  206.         final double u2 = u * u;
  207.         double coeff = norm * FastMath.exp(-0.5 * u2);
  208.         if (coeff <= Precision.SAFE_MIN) {
  209.             Arrays.fill(f, 0.0);
  210.         } else {
  211.             f[0] = coeff;
  212.             for (int n = 1; n < f.length; ++n) {

  213.                 // update and evaluate polynomial P_n(x)
  214.                 double v = 0;
  215.                 p[n] = -p[n - 1];
  216.                 for (int k = n; k >= 0; k -= 2) {
  217.                     v = v * u2 + p[k];
  218.                     if (k > 2) {
  219.                         p[k - 2] = (k - 1) * p[k - 1] - p[k - 3];
  220.                     } else if (k == 2) {
  221.                         p[0] = p[1];
  222.                     }
  223.                 }
  224.                 if ((n & 0x1) == 1) {
  225.                     v *= u;
  226.                 }

  227.                 coeff *= is;
  228.                 f[n] = coeff * v;

  229.             }
  230.         }

  231.         return t.compose(f);

  232.     }

  233. }