SimpleCurveFitter.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      https://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. /*
  18.  * This is not the original file distributed by the Apache Software Foundation
  19.  * It has been modified by the Hipparchus project
  20.  */
  21. package org.hipparchus.fitting;

  import java.util.Collection;
  import java.util.Objects;

  import org.hipparchus.analysis.ParametricUnivariateFunction;
  import org.hipparchus.linear.DiagonalMatrix;
  import org.hipparchus.optim.nonlinear.vector.leastsquares.LeastSquaresBuilder;
  import org.hipparchus.optim.nonlinear.vector.leastsquares.LeastSquaresProblem;

  27. /**
  28.  * Fits points to a user-defined {@link ParametricUnivariateFunction function}.
  29.  *
  30.  */
  31. public class SimpleCurveFitter extends AbstractCurveFitter {
  32.     /** Function to fit. */
  33.     private final ParametricUnivariateFunction function;
  34.     /** Initial guess for the parameters. */
  35.     private final double[] initialGuess;
  36.     /** Maximum number of iterations of the optimization algorithm. */
  37.     private final int maxIter;

  38.     /**
  39.      * Constructor used by the factory methods.
  40.      *
  41.      * @param function Function to fit.
  42.      * @param initialGuess Initial guess. Cannot be {@code null}. Its length must
  43.      * be consistent with the number of parameters of the {@code function} to fit.
  44.      * @param maxIter Maximum number of iterations of the optimization algorithm.
  45.      */
  46.     private SimpleCurveFitter(ParametricUnivariateFunction function, double[] initialGuess, int maxIter) {
  47.         this.function = function;
  48.         this.initialGuess = initialGuess.clone();
  49.         this.maxIter = maxIter;
  50.     }

  51.     /**
  52.      * Creates a curve fitter.
  53.      * The maximum number of iterations of the optimization algorithm is set
  54.      * to {@link Integer#MAX_VALUE}.
  55.      *
  56.      * @param f Function to fit.
  57.      * @param start Initial guess for the parameters.  Cannot be {@code null}.
  58.      * Its length must be consistent with the number of parameters of the
  59.      * function to fit.
  60.      * @return a curve fitter.
  61.      *
  62.      * @see #withStartPoint(double[])
  63.      * @see #withMaxIterations(int)
  64.      */
  65.     public static SimpleCurveFitter create(ParametricUnivariateFunction f,
  66.                                            double[] start) {
  67.         return new SimpleCurveFitter(f, start, Integer.MAX_VALUE);
  68.     }

  69.     /**
  70.      * Configure the start point (initial guess).
  71.      * @param newStart new start point (initial guess)
  72.      * @return a new instance.
  73.      */
  74.     public SimpleCurveFitter withStartPoint(double[] newStart) {
  75.         return new SimpleCurveFitter(function,
  76.                                      newStart.clone(),
  77.                                      maxIter);
  78.     }

  79.     /**
  80.      * Configure the maximum number of iterations.
  81.      * @param newMaxIter maximum number of iterations
  82.      * @return a new instance.
  83.      */
  84.     public SimpleCurveFitter withMaxIterations(int newMaxIter) {
  85.         return new SimpleCurveFitter(function,
  86.                                      initialGuess,
  87.                                      newMaxIter);
  88.     }

  89.     /** {@inheritDoc} */
  90.     @Override
  91.     protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) {
  92.         // Prepare least-squares problem.
  93.         final int len = observations.size();
  94.         final double[] target  = new double[len];
  95.         final double[] weights = new double[len];

  96.         int count = 0;
  97.         for (WeightedObservedPoint obs : observations) {
  98.             target[count]  = obs.getY();
  99.             weights[count] = obs.getWeight();
  100.             ++count;
  101.         }

  102.         final AbstractCurveFitter.TheoreticalValuesFunction model
  103.             = new AbstractCurveFitter.TheoreticalValuesFunction(function,
  104.                                                                 observations);

  105.         // Create an optimizer for fitting the curve to the observed points.
  106.         return new LeastSquaresBuilder().
  107.                 maxEvaluations(Integer.MAX_VALUE).
  108.                 maxIterations(maxIter).
  109.                 start(initialGuess).
  110.                 target(target).
  111.                 weight(new DiagonalMatrix(weights)).
  112.                 model(model.getModelFunction(), model.getModelFunctionJacobian()).
  113.                 build();
  114.     }
  115. }