1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one or more 3 * contributor license agreements. See the NOTICE file distributed with 4 * this work for additional information regarding copyright ownership. 5 * The ASF licenses this file to You under the Apache License, Version 2.0 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * https://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 18 /* 19 * This is not the original file distributed by the Apache Software Foundation 20 * It has been modified by the Hipparchus project 21 */ 22 package org.hipparchus.analysis.interpolation; 23 24 import java.io.Serializable; 25 import java.util.Arrays; 26 27 import org.hipparchus.analysis.polynomials.PolynomialSplineFunction; 28 import org.hipparchus.exception.LocalizedCoreFormats; 29 import org.hipparchus.exception.MathIllegalArgumentException; 30 import org.hipparchus.util.FastMath; 31 import org.hipparchus.util.MathArrays; 32 import org.hipparchus.util.MathUtils; 33 34 /** 35 * Implements the <a href="http://en.wikipedia.org/wiki/Local_regression"> 36 * Local Regression Algorithm</a> (also Loess, Lowess) for interpolation of 37 * real univariate functions. 38 * <p> 39 * For reference, see 40 * <a href="http://amstat.tandfonline.com/doi/abs/10.1080/01621459.1979.10481038"> 41 * William S. Cleveland - Robust Locally Weighted Regression and Smoothing 42 * Scatterplots</a></p> 43 * <p> 44 * This class implements both the loess method and serves as an interpolation 45 * adapter to it, allowing one to build a spline on the obtained loess fit.</p> 46 * 47 */ 48 public class LoessInterpolator 49 implements UnivariateInterpolator, Serializable { 50 /** Default value of the bandwidth parameter. */ 51 public static final double DEFAULT_BANDWIDTH = 0.3; 52 /** Default value of the number of robustness iterations. */ 53 public static final int DEFAULT_ROBUSTNESS_ITERS = 2; 54 /** 55 * Default value for accuracy. 56 */ 57 public static final double DEFAULT_ACCURACY = 1e-12; 58 /** serializable version identifier. */ 59 private static final long serialVersionUID = 5204927143605193821L; 60 /** 61 * The bandwidth parameter: when computing the loess fit at 62 * a particular point, this fraction of source points closest 63 * to the current point is taken into account for computing 64 * a least-squares regression. 65 * <p> 66 * A sensible value is usually 0.25 to 0.5.</p> 67 */ 68 private final double bandwidth; 69 /** 70 * The number of robustness iterations parameter: this many 71 * robustness iterations are done. 72 * <p> 73 * A sensible value is usually 0 (just the initial fit without any 74 * robustness iterations) to 4.</p> 75 */ 76 private final int robustnessIters; 77 /** 78 * If the median residual at a certain robustness iteration 79 * is less than this amount, no more iterations are done. 80 */ 81 private final double accuracy; 82 83 /** 84 * Constructs a new {@link LoessInterpolator} 85 * with a bandwidth of {@link #DEFAULT_BANDWIDTH}, 86 * {@link #DEFAULT_ROBUSTNESS_ITERS} robustness iterations 87 * and an accuracy of {#link #DEFAULT_ACCURACY}. 88 * See {@link #LoessInterpolator(double, int, double)} for an explanation of 89 * the parameters. 90 */ 91 public LoessInterpolator() { 92 this.bandwidth = DEFAULT_BANDWIDTH; 93 this.robustnessIters = DEFAULT_ROBUSTNESS_ITERS; 94 this.accuracy = DEFAULT_ACCURACY; 95 } 96 97 /** 98 * Construct a new {@link LoessInterpolator} 99 * with given bandwidth and number of robustness iterations. 100 * <p> 101 * Calling this constructor is equivalent to calling {link {@link 102 * #LoessInterpolator(double, int, double) LoessInterpolator(bandwidth, 103 * robustnessIters, LoessInterpolator.DEFAULT_ACCURACY)} 104 * </p> 105 * 106 * @param bandwidth when computing the loess fit at 107 * a particular point, this fraction of source points closest 108 * to the current point is taken into account for computing 109 * a least-squares regression. 110 * A sensible value is usually 0.25 to 0.5, the default value is 111 * {@link #DEFAULT_BANDWIDTH}. 112 * @param robustnessIters This many robustness iterations are done. 113 * A sensible value is usually 0 (just the initial fit without any 114 * robustness iterations) to 4, the default value is 115 * {@link #DEFAULT_ROBUSTNESS_ITERS}. 116 117 * @see #LoessInterpolator(double, int, double) 118 */ 119 public LoessInterpolator(double bandwidth, int robustnessIters) { 120 this(bandwidth, robustnessIters, DEFAULT_ACCURACY); 121 } 122 123 /** 124 * Construct a new {@link LoessInterpolator} 125 * with given bandwidth, number of robustness iterations and accuracy. 126 * 127 * @param bandwidth when computing the loess fit at 128 * a particular point, this fraction of source points closest 129 * to the current point is taken into account for computing 130 * a least-squares regression. 131 * A sensible value is usually 0.25 to 0.5, the default value is 132 * {@link #DEFAULT_BANDWIDTH}. 133 * @param robustnessIters This many robustness iterations are done. 134 * A sensible value is usually 0 (just the initial fit without any 135 * robustness iterations) to 4, the default value is 136 * {@link #DEFAULT_ROBUSTNESS_ITERS}. 137 * @param accuracy If the median residual at a certain robustness iteration 138 * is less than this amount, no more iterations are done. 139 * @throws MathIllegalArgumentException if bandwidth does not lie in the interval [0,1]. 140 * @throws MathIllegalArgumentException if {@code robustnessIters} is negative. 141 * @see #LoessInterpolator(double, int) 142 */ 143 public LoessInterpolator(double bandwidth, int robustnessIters, double accuracy) 144 throws MathIllegalArgumentException { 145 if (bandwidth < 0 || 146 bandwidth > 1) { 147 throw new MathIllegalArgumentException(LocalizedCoreFormats.BANDWIDTH, bandwidth, 0, 1); 148 } 149 this.bandwidth = bandwidth; 150 if (robustnessIters < 0) { 151 throw new MathIllegalArgumentException(LocalizedCoreFormats.ROBUSTNESS_ITERATIONS, robustnessIters); 152 } 153 this.robustnessIters = robustnessIters; 154 this.accuracy = accuracy; 155 } 156 157 /** 158 * Compute an interpolating function by performing a loess fit 159 * on the data at the original abscissae and then building a cubic spline 160 * with a 161 * {@link org.hipparchus.analysis.interpolation.SplineInterpolator} 162 * on the resulting fit. 163 * 164 * @param xval the arguments for the interpolation points 165 * @param yval the values for the interpolation points 166 * @return A cubic spline built upon a loess fit to the data at the original abscissae 167 * @throws MathIllegalArgumentException if {@code xval} not sorted in 168 * strictly increasing order. 169 * @throws MathIllegalArgumentException if {@code xval} and {@code yval} have 170 * different sizes. 171 * @throws MathIllegalArgumentException if {@code xval} or {@code yval} has zero size. 172 * @throws MathIllegalArgumentException if any of the arguments and values are 173 * not finite real numbers. 174 * @throws MathIllegalArgumentException if the bandwidth is too small to 175 * accomodate the size of the input data (i.e. the bandwidth must be 176 * larger than 2/n). 177 */ 178 @Override 179 public final PolynomialSplineFunction interpolate(final double[] xval, 180 final double[] yval) 181 throws MathIllegalArgumentException { 182 return new SplineInterpolator().interpolate(xval, smooth(xval, yval)); 183 } 184 185 /** 186 * Compute a weighted loess fit on the data at the original abscissae. 187 * 188 * @param xval Arguments for the interpolation points. 189 * @param yval Values for the interpolation points. 190 * @param weights point weights: coefficients by which the robustness weight 191 * of a point is multiplied. 192 * @return the values of the loess fit at corresponding original abscissae. 193 * @throws MathIllegalArgumentException if {@code xval} not sorted in 194 * strictly increasing order. 195 * @throws MathIllegalArgumentException if {@code xval} and {@code yval} have 196 * different sizes. 197 * @throws MathIllegalArgumentException if {@code xval} or {@code yval} has zero size. 198 * @throws MathIllegalArgumentException if any of the arguments and values are 199 not finite real numbers. 200 * @throws MathIllegalArgumentException if the bandwidth is too small to 201 * accomodate the size of the input data (i.e. the bandwidth must be 202 * larger than 2/n). 203 */ 204 public final double[] smooth(final double[] xval, final double[] yval, 205 final double[] weights) 206 throws MathIllegalArgumentException { 207 if (xval.length != yval.length) { 208 throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH, 209 xval.length, yval.length); 210 } 211 212 final int n = xval.length; 213 214 if (n == 0) { 215 throw new MathIllegalArgumentException(LocalizedCoreFormats.NO_DATA); 216 } 217 218 checkAllFiniteReal(xval); 219 checkAllFiniteReal(yval); 220 checkAllFiniteReal(weights); 221 222 MathArrays.checkOrder(xval); 223 224 if (n == 1) { 225 return new double[]{yval[0]}; 226 } 227 228 if (n == 2) { 229 return new double[]{yval[0], yval[1]}; 230 } 231 232 int bandwidthInPoints = (int) (bandwidth * n); 233 234 if (bandwidthInPoints < 2) { 235 throw new MathIllegalArgumentException(LocalizedCoreFormats.BANDWIDTH, 236 bandwidthInPoints, 2, true); 237 } 238 239 final double[] res = new double[n]; 240 241 final double[] residuals = new double[n]; 242 final double[] sortedResiduals = new double[n]; 243 244 final double[] robustnessWeights = new double[n]; 245 246 // Do an initial fit and 'robustnessIters' robustness iterations. 247 // This is equivalent to doing 'robustnessIters+1' robustness iterations 248 // starting with all robustness weights set to 1. 249 Arrays.fill(robustnessWeights, 1); 250 251 for (int iter = 0; iter <= robustnessIters; ++iter) { 252 final int[] bandwidthInterval = {0, bandwidthInPoints - 1}; 253 // At each x, compute a local weighted linear regression 254 for (int i = 0; i < n; ++i) { 255 final double x = xval[i]; 256 257 // Find out the interval of source points on which 258 // a regression is to be made. 259 if (i > 0) { 260 updateBandwidthInterval(xval, weights, i, bandwidthInterval); 261 } 262 263 final int ileft = bandwidthInterval[0]; 264 final int iright = bandwidthInterval[1]; 265 266 // Compute the point of the bandwidth interval that is 267 // farthest from x 268 final int edge; 269 if (xval[i] - xval[ileft] > xval[iright] - xval[i]) { 270 edge = ileft; 271 } else { 272 edge = iright; 273 } 274 275 // Compute a least-squares linear fit weighted by 276 // the product of robustness weights and the tricube 277 // weight function. 278 // See http://en.wikipedia.org/wiki/Linear_regression 279 // (section "Univariate linear case") 280 // and http://en.wikipedia.org/wiki/Weighted_least_squares 281 // (section "Weighted least squares") 282 double sumWeights = 0; 283 double sumX = 0; 284 double sumXSquared = 0; 285 double sumY = 0; 286 double sumXY = 0; 287 double denom = FastMath.abs(1.0 / (xval[edge] - x)); 288 for (int k = ileft; k <= iright; ++k) { 289 final double xk = xval[k]; 290 final double yk = yval[k]; 291 final double dist = (k < i) ? x - xk : xk - x; 292 final double w = tricube(dist * denom) * robustnessWeights[k] * weights[k]; 293 final double xkw = xk * w; 294 sumWeights += w; 295 sumX += xkw; 296 sumXSquared += xk * xkw; 297 sumY += yk * w; 298 sumXY += yk * xkw; 299 } 300 301 final double meanX = sumX / sumWeights; 302 final double meanY = sumY / sumWeights; 303 final double meanXY = sumXY / sumWeights; 304 final double meanXSquared = sumXSquared / sumWeights; 305 306 final double beta; 307 if (FastMath.sqrt(FastMath.abs(meanXSquared - meanX * meanX)) < accuracy) { 308 beta = 0; 309 } else { 310 beta = (meanXY - meanX * meanY) / (meanXSquared - meanX * meanX); 311 } 312 313 final double alpha = meanY - beta * meanX; 314 315 res[i] = beta * x + alpha; 316 residuals[i] = FastMath.abs(yval[i] - res[i]); 317 } 318 319 // No need to recompute the robustness weights at the last 320 // iteration, they won't be needed anymore 321 if (iter == robustnessIters) { 322 break; 323 } 324 325 // Recompute the robustness weights. 326 327 // Find the median residual. 328 // An arraycopy and a sort are completely tractable here, 329 // because the preceding loop is a lot more expensive 330 System.arraycopy(residuals, 0, sortedResiduals, 0, n); 331 Arrays.sort(sortedResiduals); 332 final double medianResidual = sortedResiduals[n / 2]; 333 334 if (FastMath.abs(medianResidual) < accuracy) { 335 break; 336 } 337 338 for (int i = 0; i < n; ++i) { 339 final double arg = residuals[i] / (6 * medianResidual); 340 if (arg >= 1) { 341 robustnessWeights[i] = 0; 342 } else { 343 final double w = 1 - arg * arg; 344 robustnessWeights[i] = w * w; 345 } 346 } 347 } 348 349 return res; 350 } 351 352 /** 353 * Compute a loess fit on the data at the original abscissae. 354 * 355 * @param xval the arguments for the interpolation points 356 * @param yval the values for the interpolation points 357 * @return values of the loess fit at corresponding original abscissae 358 * @throws MathIllegalArgumentException if {@code xval} not sorted in 359 * strictly increasing order. 360 * @throws MathIllegalArgumentException if {@code xval} and {@code yval} have 361 * different sizes. 362 * @throws MathIllegalArgumentException if {@code xval} or {@code yval} has zero size. 363 * @throws MathIllegalArgumentException if any of the arguments and values are 364 * not finite real numbers. 365 * @throws MathIllegalArgumentException if the bandwidth is too small to 366 * accomodate the size of the input data (i.e. the bandwidth must be 367 * larger than 2/n). 368 */ 369 public final double[] smooth(final double[] xval, final double[] yval) 370 throws MathIllegalArgumentException { 371 if (xval.length != yval.length) { 372 throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH, 373 xval.length, yval.length); 374 } 375 376 final double[] unitWeights = new double[xval.length]; 377 Arrays.fill(unitWeights, 1.0); 378 379 return smooth(xval, yval, unitWeights); 380 } 381 382 /** 383 * Given an index interval into xval that embraces a certain number of 384 * points closest to {@code xval[i-1]}, update the interval so that it 385 * embraces the same number of points closest to {@code xval[i]}, 386 * ignoring zero weights. 387 * 388 * @param xval Arguments array. 389 * @param weights Weights array. 390 * @param i Index around which the new interval should be computed. 391 * @param bandwidthInterval a two-element array {left, right} such that: 392 * {@code (left==0 or xval[i] - xval[left-1] > xval[right] - xval[i])} 393 * and 394 * {@code (right==xval.length-1 or xval[right+1] - xval[i] > xval[i] - xval[left])}. 395 * The array will be updated. 396 */ 397 private static void updateBandwidthInterval(final double[] xval, final double[] weights, 398 final int i, 399 final int[] bandwidthInterval) { 400 final int left = bandwidthInterval[0]; 401 final int right = bandwidthInterval[1]; 402 403 // The right edge should be adjusted if the next point to the right 404 // is closer to xval[i] than the leftmost point of the current interval 405 int nextRight = nextNonzero(weights, right); 406 if (nextRight < xval.length && xval[nextRight] - xval[i] < xval[i] - xval[left]) { 407 int nextLeft = nextNonzero(weights, bandwidthInterval[0]); 408 bandwidthInterval[0] = nextLeft; 409 bandwidthInterval[1] = nextRight; 410 } 411 } 412 413 /** 414 * Return the smallest index {@code j} such that 415 * {@code j > i && (j == weights.length || weights[j] != 0)}. 416 * 417 * @param weights Weights array. 418 * @param i Index from which to start search. 419 * @return the smallest compliant index. 420 */ 421 private static int nextNonzero(final double[] weights, final int i) { 422 int j = i + 1; 423 while(j < weights.length && weights[j] == 0) { 424 ++j; 425 } 426 return j; 427 } 428 429 /** 430 * Compute the 431 * <a href="http://en.wikipedia.org/wiki/Local_regression#Weight_function">tricube</a> 432 * weight function 433 * 434 * @param x Argument. 435 * @return <code>(1 - |x|<sup>3</sup>)<sup>3</sup></code> for |x| < 1, 0 otherwise. 436 */ 437 private static double tricube(final double x) { 438 final double absX = FastMath.abs(x); 439 if (absX >= 1.0) { 440 return 0.0; 441 } 442 final double tmp = 1 - absX * absX * absX; 443 return tmp * tmp * tmp; 444 } 445 446 /** 447 * Check that all elements of an array are finite real numbers. 448 * 449 * @param values Values array. 450 * @throws org.hipparchus.exception.MathIllegalArgumentException 451 * if one of the values is not a finite real number. 452 */ 453 private static void checkAllFiniteReal(final double[] values) { 454 for (int i = 0; i < values.length; i++) { 455 MathUtils.checkFinite(values[i]); 456 } 457 } 458 }