LoessInterpolator.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      https://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. /*
  18.  * This is not the original file distributed by the Apache Software Foundation
  19.  * It has been modified by the Hipparchus project
  20.  */
  21. package org.hipparchus.analysis.interpolation;

  22. import java.io.Serializable;
  23. import java.util.Arrays;

  24. import org.hipparchus.analysis.polynomials.PolynomialSplineFunction;
  25. import org.hipparchus.exception.LocalizedCoreFormats;
  26. import org.hipparchus.exception.MathIllegalArgumentException;
  27. import org.hipparchus.util.FastMath;
  28. import org.hipparchus.util.MathArrays;
  29. import org.hipparchus.util.MathUtils;

  30. /**
  31.  * Implements the <a href="http://en.wikipedia.org/wiki/Local_regression">
  32.  * Local Regression Algorithm</a> (also Loess, Lowess) for interpolation of
  33.  * real univariate functions.
  34.  * <p>
  35.  * For reference, see
  36.  * <a href="http://amstat.tandfonline.com/doi/abs/10.1080/01621459.1979.10481038">
  37.  * William S. Cleveland - Robust Locally Weighted Regression and Smoothing
  38.  * Scatterplots</a></p>
  39.  * <p>
  40.  * This class implements both the loess method and serves as an interpolation
  41.  * adapter to it, allowing one to build a spline on the obtained loess fit.</p>
  42.  *
  43.  */
  44. public class LoessInterpolator
  45.     implements UnivariateInterpolator, Serializable {
  46.     /** Default value of the bandwidth parameter. */
  47.     public static final double DEFAULT_BANDWIDTH = 0.3;
  48.     /** Default value of the number of robustness iterations. */
  49.     public static final int DEFAULT_ROBUSTNESS_ITERS = 2;
  50.     /**
  51.      * Default value for accuracy.
  52.      */
  53.     public static final double DEFAULT_ACCURACY = 1e-12;
  54.     /** serializable version identifier. */
  55.     private static final long serialVersionUID = 5204927143605193821L;
  56.     /**
  57.      * The bandwidth parameter: when computing the loess fit at
  58.      * a particular point, this fraction of source points closest
  59.      * to the current point is taken into account for computing
  60.      * a least-squares regression.
  61.      * <p>
  62.      * A sensible value is usually 0.25 to 0.5.</p>
  63.      */
  64.     private final double bandwidth;
  65.     /**
  66.      * The number of robustness iterations parameter: this many
  67.      * robustness iterations are done.
  68.      * <p>
  69.      * A sensible value is usually 0 (just the initial fit without any
  70.      * robustness iterations) to 4.</p>
  71.      */
  72.     private final int robustnessIters;
  73.     /**
  74.      * If the median residual at a certain robustness iteration
  75.      * is less than this amount, no more iterations are done.
  76.      */
  77.     private final double accuracy;

  78.     /**
  79.      * Constructs a new {@link LoessInterpolator}
  80.      * with a bandwidth of {@link #DEFAULT_BANDWIDTH},
  81.      * {@link #DEFAULT_ROBUSTNESS_ITERS} robustness iterations
  82.      * and an accuracy of {#link #DEFAULT_ACCURACY}.
  83.      * See {@link #LoessInterpolator(double, int, double)} for an explanation of
  84.      * the parameters.
  85.      */
  86.     public LoessInterpolator() {
  87.         this.bandwidth = DEFAULT_BANDWIDTH;
  88.         this.robustnessIters = DEFAULT_ROBUSTNESS_ITERS;
  89.         this.accuracy = DEFAULT_ACCURACY;
  90.     }

  91.     /**
  92.      * Construct a new {@link LoessInterpolator}
  93.      * with given bandwidth and number of robustness iterations.
  94.      * <p>
  95.      * Calling this constructor is equivalent to calling {link {@link
  96.      * #LoessInterpolator(double, int, double) LoessInterpolator(bandwidth,
  97.      * robustnessIters, LoessInterpolator.DEFAULT_ACCURACY)}
  98.      * </p>
  99.      *
  100.      * @param bandwidth  when computing the loess fit at
  101.      * a particular point, this fraction of source points closest
  102.      * to the current point is taken into account for computing
  103.      * a least-squares regression.
  104.      * A sensible value is usually 0.25 to 0.5, the default value is
  105.      * {@link #DEFAULT_BANDWIDTH}.
  106.      * @param robustnessIters This many robustness iterations are done.
  107.      * A sensible value is usually 0 (just the initial fit without any
  108.      * robustness iterations) to 4, the default value is
  109.      * {@link #DEFAULT_ROBUSTNESS_ITERS}.

  110.      * @see #LoessInterpolator(double, int, double)
  111.      */
  112.     public LoessInterpolator(double bandwidth, int robustnessIters) {
  113.         this(bandwidth, robustnessIters, DEFAULT_ACCURACY);
  114.     }

  115.     /**
  116.      * Construct a new {@link LoessInterpolator}
  117.      * with given bandwidth, number of robustness iterations and accuracy.
  118.      *
  119.      * @param bandwidth  when computing the loess fit at
  120.      * a particular point, this fraction of source points closest
  121.      * to the current point is taken into account for computing
  122.      * a least-squares regression.
  123.      * A sensible value is usually 0.25 to 0.5, the default value is
  124.      * {@link #DEFAULT_BANDWIDTH}.
  125.      * @param robustnessIters This many robustness iterations are done.
  126.      * A sensible value is usually 0 (just the initial fit without any
  127.      * robustness iterations) to 4, the default value is
  128.      * {@link #DEFAULT_ROBUSTNESS_ITERS}.
  129.      * @param accuracy If the median residual at a certain robustness iteration
  130.      * is less than this amount, no more iterations are done.
  131.      * @throws MathIllegalArgumentException if bandwidth does not lie in the interval [0,1].
  132.      * @throws MathIllegalArgumentException if {@code robustnessIters} is negative.
  133.      * @see #LoessInterpolator(double, int)
  134.      */
  135.     public LoessInterpolator(double bandwidth, int robustnessIters, double accuracy)
  136.         throws MathIllegalArgumentException {
  137.         if (bandwidth < 0 ||
  138.             bandwidth > 1) {
  139.             throw new MathIllegalArgumentException(LocalizedCoreFormats.BANDWIDTH, bandwidth, 0, 1);
  140.         }
  141.         this.bandwidth = bandwidth;
  142.         if (robustnessIters < 0) {
  143.             throw new MathIllegalArgumentException(LocalizedCoreFormats.ROBUSTNESS_ITERATIONS, robustnessIters);
  144.         }
  145.         this.robustnessIters = robustnessIters;
  146.         this.accuracy = accuracy;
  147.     }

  148.     /**
  149.      * Compute an interpolating function by performing a loess fit
  150.      * on the data at the original abscissae and then building a cubic spline
  151.      * with a
  152.      * {@link org.hipparchus.analysis.interpolation.SplineInterpolator}
  153.      * on the resulting fit.
  154.      *
  155.      * @param xval the arguments for the interpolation points
  156.      * @param yval the values for the interpolation points
  157.      * @return A cubic spline built upon a loess fit to the data at the original abscissae
  158.      * @throws MathIllegalArgumentException if {@code xval} not sorted in
  159.      * strictly increasing order.
  160.      * @throws MathIllegalArgumentException if {@code xval} and {@code yval} have
  161.      * different sizes.
  162.      * @throws MathIllegalArgumentException if {@code xval} or {@code yval} has zero size.
  163.      * @throws MathIllegalArgumentException if any of the arguments and values are
  164.      * not finite real numbers.
  165.      * @throws MathIllegalArgumentException if the bandwidth is too small to
  166.      * accomodate the size of the input data (i.e. the bandwidth must be
  167.      * larger than 2/n).
  168.      */
  169.     @Override
  170.     public final PolynomialSplineFunction interpolate(final double[] xval,
  171.                                                       final double[] yval)
  172.         throws MathIllegalArgumentException {
  173.         return new SplineInterpolator().interpolate(xval, smooth(xval, yval));
  174.     }

  175.     /**
  176.      * Compute a weighted loess fit on the data at the original abscissae.
  177.      *
  178.      * @param xval Arguments for the interpolation points.
  179.      * @param yval Values for the interpolation points.
  180.      * @param weights point weights: coefficients by which the robustness weight
  181.      * of a point is multiplied.
  182.      * @return the values of the loess fit at corresponding original abscissae.
  183.      * @throws MathIllegalArgumentException if {@code xval} not sorted in
  184.      * strictly increasing order.
  185.      * @throws MathIllegalArgumentException if {@code xval} and {@code yval} have
  186.      * different sizes.
  187.      * @throws MathIllegalArgumentException if {@code xval} or {@code yval} has zero size.
  188.      * @throws MathIllegalArgumentException if any of the arguments and values are
  189.      not finite real numbers.
  190.      * @throws MathIllegalArgumentException if the bandwidth is too small to
  191.      * accomodate the size of the input data (i.e. the bandwidth must be
  192.      * larger than 2/n).
  193.      */
  194.     public final double[] smooth(final double[] xval, final double[] yval,
  195.                                  final double[] weights)
  196.         throws MathIllegalArgumentException {
  197.         if (xval.length != yval.length) {
  198.             throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
  199.                                                    xval.length, yval.length);
  200.         }

  201.         final int n = xval.length;

  202.         if (n == 0) {
  203.             throw new MathIllegalArgumentException(LocalizedCoreFormats.NO_DATA);
  204.         }

  205.         checkAllFiniteReal(xval);
  206.         checkAllFiniteReal(yval);
  207.         checkAllFiniteReal(weights);

  208.         MathArrays.checkOrder(xval);

  209.         if (n == 1) {
  210.             return new double[]{yval[0]};
  211.         }

  212.         if (n == 2) {
  213.             return new double[]{yval[0], yval[1]};
  214.         }

  215.         int bandwidthInPoints = (int) (bandwidth * n);

  216.         if (bandwidthInPoints < 2) {
  217.             throw new MathIllegalArgumentException(LocalizedCoreFormats.BANDWIDTH,
  218.                                                 bandwidthInPoints, 2, true);
  219.         }

  220.         final double[] res = new double[n];

  221.         final double[] residuals = new double[n];
  222.         final double[] sortedResiduals = new double[n];

  223.         final double[] robustnessWeights = new double[n];

  224.         // Do an initial fit and 'robustnessIters' robustness iterations.
  225.         // This is equivalent to doing 'robustnessIters+1' robustness iterations
  226.         // starting with all robustness weights set to 1.
  227.         Arrays.fill(robustnessWeights, 1);

  228.         for (int iter = 0; iter <= robustnessIters; ++iter) {
  229.             final int[] bandwidthInterval = {0, bandwidthInPoints - 1};
  230.             // At each x, compute a local weighted linear regression
  231.             for (int i = 0; i < n; ++i) {
  232.                 final double x = xval[i];

  233.                 // Find out the interval of source points on which
  234.                 // a regression is to be made.
  235.                 if (i > 0) {
  236.                     updateBandwidthInterval(xval, weights, i, bandwidthInterval);
  237.                 }

  238.                 final int ileft = bandwidthInterval[0];
  239.                 final int iright = bandwidthInterval[1];

  240.                 // Compute the point of the bandwidth interval that is
  241.                 // farthest from x
  242.                 final int edge;
  243.                 if (xval[i] - xval[ileft] > xval[iright] - xval[i]) {
  244.                     edge = ileft;
  245.                 } else {
  246.                     edge = iright;
  247.                 }

  248.                 // Compute a least-squares linear fit weighted by
  249.                 // the product of robustness weights and the tricube
  250.                 // weight function.
  251.                 // See http://en.wikipedia.org/wiki/Linear_regression
  252.                 // (section "Univariate linear case")
  253.                 // and http://en.wikipedia.org/wiki/Weighted_least_squares
  254.                 // (section "Weighted least squares")
  255.                 double sumWeights = 0;
  256.                 double sumX = 0;
  257.                 double sumXSquared = 0;
  258.                 double sumY = 0;
  259.                 double sumXY = 0;
  260.                 double denom = FastMath.abs(1.0 / (xval[edge] - x));
  261.                 for (int k = ileft; k <= iright; ++k) {
  262.                     final double xk   = xval[k];
  263.                     final double yk   = yval[k];
  264.                     final double dist = (k < i) ? x - xk : xk - x;
  265.                     final double w    = tricube(dist * denom) * robustnessWeights[k] * weights[k];
  266.                     final double xkw  = xk * w;
  267.                     sumWeights += w;
  268.                     sumX += xkw;
  269.                     sumXSquared += xk * xkw;
  270.                     sumY += yk * w;
  271.                     sumXY += yk * xkw;
  272.                 }

  273.                 final double meanX = sumX / sumWeights;
  274.                 final double meanY = sumY / sumWeights;
  275.                 final double meanXY = sumXY / sumWeights;
  276.                 final double meanXSquared = sumXSquared / sumWeights;

  277.                 final double beta;
  278.                 if (FastMath.sqrt(FastMath.abs(meanXSquared - meanX * meanX)) < accuracy) {
  279.                     beta = 0;
  280.                 } else {
  281.                     beta = (meanXY - meanX * meanY) / (meanXSquared - meanX * meanX);
  282.                 }

  283.                 final double alpha = meanY - beta * meanX;

  284.                 res[i] = beta * x + alpha;
  285.                 residuals[i] = FastMath.abs(yval[i] - res[i]);
  286.             }

  287.             // No need to recompute the robustness weights at the last
  288.             // iteration, they won't be needed anymore
  289.             if (iter == robustnessIters) {
  290.                 break;
  291.             }

  292.             // Recompute the robustness weights.

  293.             // Find the median residual.
  294.             // An arraycopy and a sort are completely tractable here,
  295.             // because the preceding loop is a lot more expensive
  296.             System.arraycopy(residuals, 0, sortedResiduals, 0, n);
  297.             Arrays.sort(sortedResiduals);
  298.             final double medianResidual = sortedResiduals[n / 2];

  299.             if (FastMath.abs(medianResidual) < accuracy) {
  300.                 break;
  301.             }

  302.             for (int i = 0; i < n; ++i) {
  303.                 final double arg = residuals[i] / (6 * medianResidual);
  304.                 if (arg >= 1) {
  305.                     robustnessWeights[i] = 0;
  306.                 } else {
  307.                     final double w = 1 - arg * arg;
  308.                     robustnessWeights[i] = w * w;
  309.                 }
  310.             }
  311.         }

  312.         return res;
  313.     }

  314.     /**
  315.      * Compute a loess fit on the data at the original abscissae.
  316.      *
  317.      * @param xval the arguments for the interpolation points
  318.      * @param yval the values for the interpolation points
  319.      * @return values of the loess fit at corresponding original abscissae
  320.      * @throws MathIllegalArgumentException if {@code xval} not sorted in
  321.      * strictly increasing order.
  322.      * @throws MathIllegalArgumentException if {@code xval} and {@code yval} have
  323.      * different sizes.
  324.      * @throws MathIllegalArgumentException if {@code xval} or {@code yval} has zero size.
  325.      * @throws MathIllegalArgumentException if any of the arguments and values are
  326.      * not finite real numbers.
  327.      * @throws MathIllegalArgumentException if the bandwidth is too small to
  328.      * accomodate the size of the input data (i.e. the bandwidth must be
  329.      * larger than 2/n).
  330.      */
  331.     public final double[] smooth(final double[] xval, final double[] yval)
  332.         throws MathIllegalArgumentException {
  333.         if (xval.length != yval.length) {
  334.             throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
  335.                                                    xval.length, yval.length);
  336.         }

  337.         final double[] unitWeights = new double[xval.length];
  338.         Arrays.fill(unitWeights, 1.0);

  339.         return smooth(xval, yval, unitWeights);
  340.     }

  341.     /**
  342.      * Given an index interval into xval that embraces a certain number of
  343.      * points closest to {@code xval[i-1]}, update the interval so that it
  344.      * embraces the same number of points closest to {@code xval[i]},
  345.      * ignoring zero weights.
  346.      *
  347.      * @param xval Arguments array.
  348.      * @param weights Weights array.
  349.      * @param i Index around which the new interval should be computed.
  350.      * @param bandwidthInterval a two-element array {left, right} such that:
  351.      * {@code (left==0 or xval[i] - xval[left-1] > xval[right] - xval[i])}
  352.      * and
  353.      * {@code (right==xval.length-1 or xval[right+1] - xval[i] > xval[i] - xval[left])}.
  354.      * The array will be updated.
  355.      */
  356.     private static void updateBandwidthInterval(final double[] xval, final double[] weights,
  357.                                                 final int i,
  358.                                                 final int[] bandwidthInterval) {
  359.         final int left = bandwidthInterval[0];
  360.         final int right = bandwidthInterval[1];

  361.         // The right edge should be adjusted if the next point to the right
  362.         // is closer to xval[i] than the leftmost point of the current interval
  363.         int nextRight = nextNonzero(weights, right);
  364.         if (nextRight < xval.length && xval[nextRight] - xval[i] < xval[i] - xval[left]) {
  365.             int nextLeft = nextNonzero(weights, bandwidthInterval[0]);
  366.             bandwidthInterval[0] = nextLeft;
  367.             bandwidthInterval[1] = nextRight;
  368.         }
  369.     }

  370.     /**
  371.      * Return the smallest index {@code j} such that
  372.      * {@code j > i && (j == weights.length || weights[j] != 0)}.
  373.      *
  374.      * @param weights Weights array.
  375.      * @param i Index from which to start search.
  376.      * @return the smallest compliant index.
  377.      */
  378.     private static int nextNonzero(final double[] weights, final int i) {
  379.         int j = i + 1;
  380.         while(j < weights.length && weights[j] == 0) {
  381.             ++j;
  382.         }
  383.         return j;
  384.     }

  385.     /**
  386.      * Compute the
  387.      * <a href="http://en.wikipedia.org/wiki/Local_regression#Weight_function">tricube</a>
  388.      * weight function
  389.      *
  390.      * @param x Argument.
  391.      * @return <code>(1 - |x|<sup>3</sup>)<sup>3</sup></code> for |x| &lt; 1, 0 otherwise.
  392.      */
  393.     private static double tricube(final double x) {
  394.         final double absX = FastMath.abs(x);
  395.         if (absX >= 1.0) {
  396.             return 0.0;
  397.         }
  398.         final double tmp = 1 - absX * absX * absX;
  399.         return tmp * tmp * tmp;
  400.     }

  401.     /**
  402.      * Check that all elements of an array are finite real numbers.
  403.      *
  404.      * @param values Values array.
  405.      * @throws org.hipparchus.exception.MathIllegalArgumentException
  406.      * if one of the values is not a finite real number.
  407.      */
  408.     private static void checkAllFiniteReal(final double[] values) {
  409.         for (double value : values) {
  410.             MathUtils.checkFinite(value);
  411.         }
  412.     }
  413. }