GaussianCurveFitter.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      https://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. /*
  18.  * This is not the original file distributed by the Apache Software Foundation
  19.  * It has been modified by the Hipparchus project
  20.  */
  21. package org.hipparchus.fitting;

  22. import java.util.ArrayList;
  23. import java.util.Collection;
  24. import java.util.Comparator;
  25. import java.util.List;

  26. import org.hipparchus.analysis.function.Gaussian;
  27. import org.hipparchus.exception.LocalizedCoreFormats;
  28. import org.hipparchus.exception.MathIllegalArgumentException;
  29. import org.hipparchus.linear.DiagonalMatrix;
  30. import org.hipparchus.optim.nonlinear.vector.leastsquares.LeastSquaresBuilder;
  31. import org.hipparchus.optim.nonlinear.vector.leastsquares.LeastSquaresProblem;
  32. import org.hipparchus.util.FastMath;
  33. import org.hipparchus.util.MathUtils;

  34. /**
  35.  * Fits points to a {@link
  36.  * org.hipparchus.analysis.function.Gaussian.Parametric Gaussian}
  37.  * function.
  38.  * <br>
  39.  * The {@link #withStartPoint(double[]) initial guess values} must be passed
  40.  * in the following order:
  41.  * <ul>
  42.  *  <li>Normalization</li>
  43.  *  <li>Mean</li>
  44.  *  <li>Sigma</li>
  45.  * </ul>
  46.  * The optimal values will be returned in the same order.
  47.  *
  48.  * <p>
  49.  * Usage example:
  50.  * <pre>
  51.  *   WeightedObservedPoints obs = new WeightedObservedPoints();
  52.  *   obs.add(4.0254623,  531026.0);
  53.  *   obs.add(4.03128248, 984167.0);
  54.  *   obs.add(4.03839603, 1887233.0);
  55.  *   obs.add(4.04421621, 2687152.0);
  56.  *   obs.add(4.05132976, 3461228.0);
  57.  *   obs.add(4.05326982, 3580526.0);
  58.  *   obs.add(4.05779662, 3439750.0);
  59.  *   obs.add(4.0636168,  2877648.0);
  60.  *   obs.add(4.06943698, 2175960.0);
  61.  *   obs.add(4.07525716, 1447024.0);
  62.  *   obs.add(4.08237071, 717104.0);
  63.  *   obs.add(4.08366408, 620014.0);
  64.  *   double[] parameters = GaussianCurveFitter.create().fit(obs.toList());
  65.  * </pre>
  66.  *
  67.  */
  68. public class GaussianCurveFitter extends AbstractCurveFitter {
  69.     /** Parametric function to be fitted. */
  70.     private static final Gaussian.Parametric FUNCTION = new Gaussian.Parametric() {
  71.             /** {@inheritDoc} */
  72.             @Override
  73.             public double value(double x, double ... p) {
  74.                 double v = Double.POSITIVE_INFINITY;
  75.                 try {
  76.                     v = super.value(x, p);
  77.                 } catch (MathIllegalArgumentException e) { // NOPMD
  78.                     // Do nothing.
  79.                 }
  80.                 return v;
  81.             }

  82.             /** {@inheritDoc} */
  83.             @Override
  84.             public double[] gradient(double x, double ... p) {
  85.                 double[] v = { Double.POSITIVE_INFINITY,
  86.                                Double.POSITIVE_INFINITY,
  87.                                Double.POSITIVE_INFINITY };
  88.                 try {
  89.                     v = super.gradient(x, p);
  90.                 } catch (MathIllegalArgumentException e) { // NOPMD
  91.                     // Do nothing.
  92.                 }
  93.                 return v;
  94.             }
  95.         };
  96.     /** Initial guess. */
  97.     private final double[] initialGuess;
  98.     /** Maximum number of iterations of the optimization algorithm. */
  99.     private final int maxIter;

  100.     /**
  101.      * Constructor used by the factory methods.
  102.      *
  103.      * @param initialGuess Initial guess. If set to {@code null}, the initial guess
  104.      * will be estimated using the {@link ParameterGuesser}.
  105.      * @param maxIter Maximum number of iterations of the optimization algorithm.
  106.      */
  107.     private GaussianCurveFitter(double[] initialGuess, int maxIter) {
  108.         this.initialGuess = initialGuess == null ? null : initialGuess.clone();
  109.         this.maxIter = maxIter;
  110.     }

  111.     /**
  112.      * Creates a default curve fitter.
  113.      * The initial guess for the parameters will be {@link ParameterGuesser}
  114.      * computed automatically, and the maximum number of iterations of the
  115.      * optimization algorithm is set to {@link Integer#MAX_VALUE}.
  116.      *
  117.      * @return a curve fitter.
  118.      *
  119.      * @see #withStartPoint(double[])
  120.      * @see #withMaxIterations(int)
  121.      */
  122.     public static GaussianCurveFitter create() {
  123.         return new GaussianCurveFitter(null, Integer.MAX_VALUE);
  124.     }

  125.     /**
  126.      * Configure the start point (initial guess).
  127.      * @param newStart new start point (initial guess)
  128.      * @return a new instance.
  129.      */
  130.     public GaussianCurveFitter withStartPoint(double[] newStart) {
  131.         return new GaussianCurveFitter(newStart.clone(),
  132.                                        maxIter);
  133.     }

  134.     /**
  135.      * Configure the maximum number of iterations.
  136.      * @param newMaxIter maximum number of iterations
  137.      * @return a new instance.
  138.      */
  139.     public GaussianCurveFitter withMaxIterations(int newMaxIter) {
  140.         return new GaussianCurveFitter(initialGuess,
  141.                                        newMaxIter);
  142.     }

  143.     /** {@inheritDoc} */
  144.     @Override
  145.     protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) {

  146.         // Prepare least-squares problem.
  147.         final int len = observations.size();
  148.         final double[] target  = new double[len];
  149.         final double[] weights = new double[len];

  150.         int i = 0;
  151.         for (WeightedObservedPoint obs : observations) {
  152.             target[i]  = obs.getY();
  153.             weights[i] = obs.getWeight();
  154.             ++i;
  155.         }

  156.         final AbstractCurveFitter.TheoreticalValuesFunction model =
  157.                 new AbstractCurveFitter.TheoreticalValuesFunction(FUNCTION, observations);

  158.         final double[] startPoint = initialGuess != null ?
  159.             initialGuess :
  160.             // Compute estimation.
  161.             new ParameterGuesser(observations).guess();

  162.         // Return a new least squares problem set up to fit a Gaussian curve to the
  163.         // observed points.
  164.         return new LeastSquaresBuilder().
  165.                 maxEvaluations(Integer.MAX_VALUE).
  166.                 maxIterations(maxIter).
  167.                 start(startPoint).
  168.                 target(target).
  169.                 weight(new DiagonalMatrix(weights)).
  170.                 model(model.getModelFunction(), model.getModelFunctionJacobian()).
  171.                 build();

  172.     }

  173.     /**
  174.      * Guesses the parameters {@code norm}, {@code mean}, and {@code sigma}
  175.      * of a {@link org.hipparchus.analysis.function.Gaussian.Parametric}
  176.      * based on the specified observed points.
  177.      */
  178.     public static class ParameterGuesser {
  179.         /** Normalization factor. */
  180.         private final double norm;
  181.         /** Mean. */
  182.         private final double mean;
  183.         /** Standard deviation. */
  184.         private final double sigma;

  185.         /**
  186.          * Constructs instance with the specified observed points.
  187.          *
  188.          * @param observations Observed points from which to guess the
  189.          * parameters of the Gaussian.
  190.          * @throws org.hipparchus.exception.NullArgumentException if {@code observations} is
  191.          * {@code null}.
  192.          * @throws MathIllegalArgumentException if there are less than 3
  193.          * observations.
  194.          */
  195.         public ParameterGuesser(Collection<WeightedObservedPoint> observations) {
  196.             MathUtils.checkNotNull(observations);
  197.             if (observations.size() < 3) {
  198.                 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL,
  199.                                                        observations.size(), 3);
  200.             }

  201.             final List<WeightedObservedPoint> sorted = sortObservations(observations);
  202.             final double[] params = basicGuess(sorted.toArray(new WeightedObservedPoint[0]));

  203.             norm = params[0];
  204.             mean = params[1];
  205.             sigma = params[2];
  206.         }

  207.         /**
  208.          * Gets an estimation of the parameters.
  209.          *
  210.          * @return the guessed parameters, in the following order:
  211.          * <ul>
  212.          *  <li>Normalization factor</li>
  213.          *  <li>Mean</li>
  214.          *  <li>Standard deviation</li>
  215.          * </ul>
  216.          */
  217.         public double[] guess() {
  218.             return new double[] { norm, mean, sigma };
  219.         }

  220.         /**
  221.          * Sort the observations.
  222.          *
  223.          * @param unsorted Input observations.
  224.          * @return the input observations, sorted.
  225.          */
  226.         private List<WeightedObservedPoint> sortObservations(Collection<WeightedObservedPoint> unsorted) {
  227.             final List<WeightedObservedPoint> observations = new ArrayList<>(unsorted);

  228.             final Comparator<WeightedObservedPoint> cmp = new Comparator<WeightedObservedPoint>() {
  229.                 /** {@inheritDoc} */
  230.                 @Override
  231.                 public int compare(WeightedObservedPoint p1,
  232.                                    WeightedObservedPoint p2) {
  233.                     if (p1 == null && p2 == null) {
  234.                         return 0;
  235.                     }
  236.                     if (p1 == null) {
  237.                         return -1;
  238.                     }
  239.                     if (p2 == null) {
  240.                         return 1;
  241.                     }
  242.                     int comp = Double.compare(p1.getX(), p2.getX());
  243.                     if (comp != 0) {
  244.                         return comp;
  245.                     }
  246.                     comp = Double.compare(p1.getY(), p2.getY());
  247.                     if (comp != 0) {
  248.                         return comp;
  249.                     }
  250.                     comp = Double.compare(p1.getWeight(), p2.getWeight());
  251.                     if (comp != 0) {
  252.                         return comp;
  253.                     }
  254.                     return 0;
  255.                 }
  256.             };

  257.             observations.sort(cmp);
  258.             return observations;
  259.         }

  260.         /**
  261.          * Guesses the parameters based on the specified observed points.
  262.          *
  263.          * @param points Observed points, sorted.
  264.          * @return the guessed parameters (normalization factor, mean and
  265.          * sigma).
  266.          */
  267.         private double[] basicGuess(WeightedObservedPoint[] points) {
  268.             final int maxYIdx = findMaxY(points);
  269.             final double n = points[maxYIdx].getY();
  270.             final double m = points[maxYIdx].getX();

  271.             double fwhmApprox;
  272.             try {
  273.                 final double halfY = n + ((m - n) / 2);
  274.                 final double fwhmX1 = interpolateXAtY(points, maxYIdx, -1, halfY);
  275.                 final double fwhmX2 = interpolateXAtY(points, maxYIdx, 1, halfY);
  276.                 fwhmApprox = fwhmX2 - fwhmX1;
  277.             } catch (MathIllegalArgumentException e) {
  278.                 // TODO: Exceptions should not be used for flow control.
  279.                 fwhmApprox = points[points.length - 1].getX() - points[0].getX();
  280.             }
  281.             final double s = fwhmApprox / (2 * FastMath.sqrt(2 * FastMath.log(2)));

  282.             return new double[] { n, m, s };
  283.         }

  284.         /**
  285.          * Finds index of point in specified points with the largest Y.
  286.          *
  287.          * @param points Points to search.
  288.          * @return the index in specified points array.
  289.          */
  290.         private int findMaxY(WeightedObservedPoint[] points) {
  291.             int maxYIdx = 0;
  292.             for (int i = 1; i < points.length; i++) {
  293.                 if (points[i].getY() > points[maxYIdx].getY()) {
  294.                     maxYIdx = i;
  295.                 }
  296.             }
  297.             return maxYIdx;
  298.         }

  299.         /**
  300.          * Interpolates using the specified points to determine X at the
  301.          * specified Y.
  302.          *
  303.          * @param points Points to use for interpolation.
  304.          * @param startIdx Index within points from which to start the search for
  305.          * interpolation bounds points.
  306.          * @param idxStep Index step for searching interpolation bounds points.
  307.          * @param y Y value for which X should be determined.
  308.          * @return the value of X for the specified Y.
  309.          * @throws MathIllegalArgumentException if {@code idxStep} is 0.
  310.          * @throws MathIllegalArgumentException if specified {@code y} is not within the
  311.          * range of the specified {@code points}.
  312.          */
  313.         private double interpolateXAtY(WeightedObservedPoint[] points,
  314.                                        int startIdx,
  315.                                        int idxStep,
  316.                                        double y)
  317.             throws MathIllegalArgumentException {
  318.             if (idxStep == 0) {
  319.                 throw new MathIllegalArgumentException(LocalizedCoreFormats.ZERO_NOT_ALLOWED);
  320.             }
  321.             final WeightedObservedPoint[] twoPoints
  322.                 = getInterpolationPointsForY(points, startIdx, idxStep, y);
  323.             final WeightedObservedPoint p1 = twoPoints[0];
  324.             final WeightedObservedPoint p2 = twoPoints[1];
  325.             if (p1.getY() == y) {
  326.                 return p1.getX();
  327.             }
  328.             if (p2.getY() == y) {
  329.                 return p2.getX();
  330.             }
  331.             return p1.getX() + (((y - p1.getY()) * (p2.getX() - p1.getX())) /
  332.                                 (p2.getY() - p1.getY()));
  333.         }

  334.         /**
  335.          * Gets the two bounding interpolation points from the specified points
  336.          * suitable for determining X at the specified Y.
  337.          *
  338.          * @param points Points to use for interpolation.
  339.          * @param startIdx Index within points from which to start search for
  340.          * interpolation bounds points.
  341.          * @param idxStep Index step for search for interpolation bounds points.
  342.          * @param y Y value for which X should be determined.
  343.          * @return the array containing two points suitable for determining X at
  344.          * the specified Y.
  345.          * @throws MathIllegalArgumentException if {@code idxStep} is 0.
  346.          * @throws MathIllegalArgumentException if specified {@code y} is not within the
  347.          * range of the specified {@code points}.
  348.          */
  349.         private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points,
  350.                                                                    int startIdx,
  351.                                                                    int idxStep,
  352.                                                                    double y)
  353.             throws MathIllegalArgumentException {
  354.             if (idxStep == 0) {
  355.                 throw new MathIllegalArgumentException(LocalizedCoreFormats.ZERO_NOT_ALLOWED);
  356.             }
  357.             for (int i = startIdx;
  358.                  idxStep < 0 ? i + idxStep >= 0 : i + idxStep < points.length;
  359.                  i += idxStep) {
  360.                 final WeightedObservedPoint p1 = points[i];
  361.                 final WeightedObservedPoint p2 = points[i + idxStep];
  362.                 if (isBetween(y, p1.getY(), p2.getY())) {
  363.                     if (idxStep < 0) {
  364.                         return new WeightedObservedPoint[] { p2, p1 };
  365.                     } else {
  366.                         return new WeightedObservedPoint[] { p1, p2 };
  367.                     }
  368.                 }
  369.             }

  370.             // Boundaries are replaced by dummy values because the raised
  371.             // exception is caught and the message never displayed.
  372.             // TODO: Exceptions should not be used for flow control.
  373.             throw new MathIllegalArgumentException(LocalizedCoreFormats.OUT_OF_RANGE_SIMPLE,
  374.                                                    y, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY);
  375.         }

  376.         /**
  377.          * Determines whether a value is between two other values.
  378.          *
  379.          * @param value Value to test whether it is between {@code boundary1}
  380.          * and {@code boundary2}.
  381.          * @param boundary1 One end of the range.
  382.          * @param boundary2 Other end of the range.
  383.          * @return {@code true} if {@code value} is between {@code boundary1} and
  384.          * {@code boundary2} (inclusive), {@code false} otherwise.
  385.          */
  386.         private boolean isBetween(double value,
  387.                                   double boundary1,
  388.                                   double boundary2) {
  389.             return (value >= boundary1 && value <= boundary2) ||
  390.                 (value >= boundary2 && value <= boundary1);
  391.         }
  392.     }
  393. }