AbstractCurveFitter.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      https://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. /*
  18.  * This is not the original file distributed by the Apache Software Foundation
  19.  * It has been modified by the Hipparchus project
  20.  */
  21. package org.hipparchus.fitting;

  22. import java.util.Collection;

  23. import org.hipparchus.analysis.MultivariateMatrixFunction;
  24. import org.hipparchus.analysis.MultivariateVectorFunction;
  25. import org.hipparchus.analysis.ParametricUnivariateFunction;
  26. import org.hipparchus.optim.nonlinear.vector.leastsquares.LeastSquaresOptimizer;
  27. import org.hipparchus.optim.nonlinear.vector.leastsquares.LeastSquaresProblem;
  28. import org.hipparchus.optim.nonlinear.vector.leastsquares.LevenbergMarquardtOptimizer;

  29. /**
  30.  * Base class that contains common code for fitting parametric univariate
  31.  * real functions <code>y = f(p<sub>i</sub>;x)</code>, where {@code x} is
  32.  * the independent variable and the <code>p<sub>i</sub></code> are the
  33.  * <em>parameters</em>.
  34.  * <br>
  35.  * A fitter will find the optimal values of the parameters by
  36.  * <em>fitting</em> the curve so it remains very close to a set of
  37.  * {@code N} observed points <code>(x<sub>k</sub>, y<sub>k</sub>)</code>,
  38.  * {@code 0 <= k < N}.
  39.  * <br>
  40.  * An algorithm usually performs the fit by finding the parameter
  41.  * values that minimizes the objective function
  42.  * <pre><code>
  43.  *  &sum;y<sub>k</sub> - f(x<sub>k</sub>)<sup>2</sup>,
  44.  * </code></pre>
  45.  * which is actually a least-squares problem.
  46.  * This class contains boilerplate code for calling the
  47.  * {@link #fit(Collection)} method for obtaining the parameters.
  48.  * The problem setup, such as the choice of optimization algorithm
  49.  * for fitting a specific function is delegated to subclasses.
  50.  *
  51.  */
  52. public abstract class AbstractCurveFitter {

  53.     /** Empty constructor.
  54.      * <p>
  55.      * This constructor is not strictly necessary, but it prevents spurious
  56.      * javadoc warnings with JDK 18 and later.
  57.      * </p>
  58.      * @since 3.0
  59.      */
  60.     protected AbstractCurveFitter() { // NOPMD - unnecessary constructor added intentionally to make javadoc happy
  61.         // nothing to do
  62.     }

  63.     /**
  64.      * Fits a curve.
  65.      * This method computes the coefficients of the curve that best
  66.      * fit the sample of observed points.
  67.      *
  68.      * @param points Observations.
  69.      * @return the fitted parameters.
  70.      */
  71.     public double[] fit(Collection<WeightedObservedPoint> points) {
  72.         // Perform the fit.
  73.         return getOptimizer().optimize(getProblem(points)).getPoint().toArray();
  74.     }

  75.     /**
  76.      * Creates an optimizer set up to fit the appropriate curve.
  77.      * <p>
  78.      * The default implementation uses a {@link LevenbergMarquardtOptimizer
  79.      * Levenberg-Marquardt} optimizer.
  80.      * </p>
  81.      * @return the optimizer to use for fitting the curve to the
  82.      * given {@code points}.
  83.      */
  84.     protected LeastSquaresOptimizer getOptimizer() {
  85.         return new LevenbergMarquardtOptimizer();
  86.     }

  87.     /**
  88.      * Creates a least squares problem corresponding to the appropriate curve.
  89.      *
  90.      * @param points Sample points.
  91.      * @return the least squares problem to use for fitting the curve to the
  92.      * given {@code points}.
  93.      */
  94.     protected abstract LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> points);

  95.     /**
  96.      * Vector function for computing function theoretical values.
  97.      */
  98.     protected static class TheoreticalValuesFunction {
  99.         /** Function to fit. */
  100.         private final ParametricUnivariateFunction f;
  101.         /** Observations. */
  102.         private final double[] points;

  103.         /** Simple constructor.
  104.          * @param f function to fit.
  105.          * @param observations Observations.
  106.          */
  107.         public TheoreticalValuesFunction(final ParametricUnivariateFunction f,
  108.                                          final Collection<WeightedObservedPoint> observations) {
  109.             this.f = f;

  110.             final int len = observations.size();
  111.             this.points = new double[len];
  112.             int i = 0;
  113.             for (WeightedObservedPoint obs : observations) {
  114.                 this.points[i++] = obs.getX();
  115.             }
  116.         }

  117.         /** Get model function value.
  118.          * @return the model function value
  119.          */
  120.         public MultivariateVectorFunction getModelFunction() {
  121.             return new MultivariateVectorFunction() {
  122.                 /** {@inheritDoc} */
  123.                 @Override
  124.                 public double[] value(double[] p) {
  125.                     final int len = points.length;
  126.                     final double[] values = new double[len];
  127.                     for (int i = 0; i < len; i++) {
  128.                         values[i] = f.value(points[i], p);
  129.                     }

  130.                     return values;
  131.                 }
  132.             };
  133.         }

  134.         /** Get model function Jacobian.
  135.          * @return the model function Jacobian
  136.          */
  137.         public MultivariateMatrixFunction getModelFunctionJacobian() {
  138.             return new MultivariateMatrixFunction() {
  139.                 /** {@inheritDoc} */
  140.                 @Override
  141.                 public double[][] value(double[] p) {
  142.                     final int len = points.length;
  143.                     final double[][] jacobian = new double[len][];
  144.                     for (int i = 0; i < len; i++) {
  145.                         jacobian[i] = f.gradient(points[i], p);
  146.                     }
  147.                     return jacobian;
  148.                 }
  149.             };
  150.         }
  151.     }
  152. }