NonLinearConjugateGradientOptimizer.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      https://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. /*
  18.  * This is not the original file distributed by the Apache Software Foundation
  19.  * It has been modified by the Hipparchus project
  20.  */

  21. package org.hipparchus.optim.nonlinear.scalar.gradient;

  22. import org.hipparchus.exception.LocalizedCoreFormats;
  23. import org.hipparchus.exception.MathIllegalStateException;
  24. import org.hipparchus.exception.MathRuntimeException;
  25. import org.hipparchus.optim.ConvergenceChecker;
  26. import org.hipparchus.optim.OptimizationData;
  27. import org.hipparchus.optim.PointValuePair;
  28. import org.hipparchus.optim.nonlinear.scalar.GoalType;
  29. import org.hipparchus.optim.nonlinear.scalar.GradientMultivariateOptimizer;
  30. import org.hipparchus.optim.nonlinear.scalar.LineSearch;


  31. /**
  32.  * Non-linear conjugate gradient optimizer.
  33.  * <br>
  34.  * This class supports both the Fletcher-Reeves and the Polak-Ribière
  35.  * update formulas for the conjugate search directions.
  36.  * It also supports optional preconditioning.
  37.  * <br>
  38.  * Constraints are not supported: the call to
  39.  * {@link #optimize(OptimizationData[]) optimize} will throw
  40.  * {@link MathRuntimeException} if bounds are passed to it.
  41.  *
  42.  */
public class NonLinearConjugateGradientOptimizer
    extends GradientMultivariateOptimizer {
    /** Update formula for the beta parameter. */
    private final Formula updateFormula;
    /** Preconditioner (may be null if passed explicitly to the full
     * constructor; the other constructors default to an
     * {@link IdentityPreconditioner}). */
    private final Preconditioner preconditioner;
    /** Line search algorithm. */
    private final LineSearch line;

    /**
     * Available choices of update formulas for the parameter
     * that is used to compute the successive conjugate search directions.
     * For non-linear conjugate gradients, there are
     * two formulas:
     * <ul>
     *   <li>Fletcher-Reeves formula</li>
     *   <li>Polak-Ribière formula</li>
     * </ul>
     *
     * On the one hand, the Fletcher-Reeves formula is guaranteed to converge
     * if the start point is close enough to the optimum, whereas the
     * Polak-Ribière formula may not converge in rare cases. On the
     * other hand, the Polak-Ribière formula is often faster when it
     * does converge. Polak-Ribière is often used.
     *
     */
    public enum Formula {
        /** Fletcher-Reeves formula. */
        FLETCHER_REEVES,
        /** Polak-Ribière formula. */
        POLAK_RIBIERE
    }

    /**
     * Constructor with default tolerances for the line search (1e-8) and
     * {@link IdentityPreconditioner preconditioner}.
     *
     * @param updateFormula formula to use for updating the &beta; parameter,
     * must be one of {@link Formula#FLETCHER_REEVES} or
     * {@link Formula#POLAK_RIBIERE}.
     * @param checker Convergence checker.
     */
    public NonLinearConjugateGradientOptimizer(final Formula updateFormula,
                                               ConvergenceChecker<PointValuePair> checker) {
        this(updateFormula,
             checker,
             1e-8,
             1e-8,
             1e-8,
             new IdentityPreconditioner());
    }

    /**
     * Constructor with default {@link IdentityPreconditioner preconditioner}.
     *
     * @param updateFormula formula to use for updating the &beta; parameter,
     * must be one of {@link Formula#FLETCHER_REEVES} or
     * {@link Formula#POLAK_RIBIERE}.
     * @param checker Convergence checker.
     * @param relativeTolerance Relative threshold for line search.
     * @param absoluteTolerance Absolute threshold for line search.
     * @param initialBracketingRange Extent of the initial interval used to
     * find an interval that brackets the optimum in order to perform the
     * line search.
     *
     * @see LineSearch#LineSearch(org.hipparchus.optim.nonlinear.scalar.MultivariateOptimizer,double,double,double)
     */
    public NonLinearConjugateGradientOptimizer(final Formula updateFormula,
                                               ConvergenceChecker<PointValuePair> checker,
                                               double relativeTolerance,
                                               double absoluteTolerance,
                                               double initialBracketingRange) {
        this(updateFormula,
             checker,
             relativeTolerance,
             absoluteTolerance,
             initialBracketingRange,
             new IdentityPreconditioner());
    }

    /** Simple constructor.
     * @param updateFormula formula to use for updating the &beta; parameter,
     * must be one of {@link Formula#FLETCHER_REEVES} or
     * {@link Formula#POLAK_RIBIERE}.
     * @param checker Convergence checker.
     * @param preconditioner Preconditioner.
     * @param relativeTolerance Relative threshold for line search.
     * @param absoluteTolerance Absolute threshold for line search.
     * @param initialBracketingRange Extent of the initial interval used to
     * find an interval that brackets the optimum in order to perform the
     * line search.
     *
     * @see LineSearch#LineSearch(org.hipparchus.optim.nonlinear.scalar.MultivariateOptimizer,double,double,double)
     */
    public NonLinearConjugateGradientOptimizer(final Formula updateFormula,
                                               ConvergenceChecker<PointValuePair> checker,
                                               double relativeTolerance,
                                               double absoluteTolerance,
                                               double initialBracketingRange,
                                               final Preconditioner preconditioner) {
        super(checker);

        this.updateFormula = updateFormula;
        this.preconditioner = preconditioner;
        // The line search performs the one-dimensional optimization along
        // each successive conjugate direction.
        line = new LineSearch(this,
                              relativeTolerance,
                              absoluteTolerance,
                              initialBracketingRange);
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public PointValuePair optimize(OptimizationData... optData)
        throws MathIllegalStateException {
        // Set up base class and perform computation.
        return super.optimize(optData);
    }

    /** {@inheritDoc} */
    @Override
    protected PointValuePair doOptimize() {
        final ConvergenceChecker<PointValuePair> checker = getConvergenceChecker();
        final double[] point = getStartPoint();
        final GoalType goal = getGoalType();
        final int n = point.length;
        // r is the objective gradient; it is negated for minimization so the
        // remainder of the loop uniformly treats the problem as an ascent.
        double[] r = computeObjectiveGradient(point);
        if (goal == GoalType.MINIMIZE) {
            for (int i = 0; i < n; i++) {
                r[i] = -r[i];
            }
        }

        // Initial search direction: preconditioned steepest descent/ascent.
        double[] steepestDescent = preconditioner.precondition(point, r);
        double[] searchDirection = steepestDescent.clone();

        // delta = r . searchDirection; later used as the denominator
        // (deltaOld) of the beta update formulas.
        double delta = 0;
        for (int i = 0; i < n; ++i) {
            delta += r[i] * searchDirection[i];
        }

        PointValuePair current = null;
        while (true) {
            // NOTE(review): presumably also enforces the maximum-iteration
            // limit set on the base optimizer — confirm in the base class.
            incrementIterationCount();

            final double objective = computeObjectiveValue(point);
            PointValuePair previous = current;
            current = new PointValuePair(point, objective);
            if (previous != null && checker.converged(getIterations(), previous, current)) {
                // We have found an optimum.
                return current;
            }

            // One-dimensional optimization along the search direction gives
            // the step length.
            final double step = line.search(point, searchDirection).getPoint();

            // Move to the new point: point += step * searchDirection
            // (updates the array in place).
            for (int i = 0; i < point.length; ++i) {
                point[i] += step * searchDirection[i];
            }

            // Gradient at the new point, again negated for minimization.
            r = computeObjectiveGradient(point);
            if (goal == GoalType.MINIMIZE) {
                for (int i = 0; i < n; ++i) {
                    r[i] = -r[i];
                }
            }

            // Compute beta.
            final double deltaOld = delta;
            final double[] newSteepestDescent = preconditioner.precondition(point, r);
            // delta = r_new . h_new (h = preconditioned gradient).
            delta = 0;
            for (int i = 0; i < n; ++i) {
                delta += r[i] * newSteepestDescent[i];
            }

            final double beta;
            switch (updateFormula) {
            case FLETCHER_REEVES:
                // beta = (r_new . h_new) / (r_old . h_old).
                beta = delta / deltaOld;
                break;
            case POLAK_RIBIERE:
                // beta = (r_new . (h_new - h_old)) / (r_old . h_old),
                // computed here as (delta - r_new . h_old) / deltaOld.
                double deltaMid = 0;
                for (int i = 0; i < r.length; ++i) {
                    deltaMid += r[i] * steepestDescent[i];
                }
                beta = (delta - deltaMid) / deltaOld;
                break;
            default:
                // Should never happen: the enum has no other constant.
                throw MathRuntimeException.createInternalError();
            }
            steepestDescent = newSteepestDescent;

            // Compute conjugate search direction.
            if (getIterations() % n == 0 ||
                beta < 0) {
                // Break conjugation: reset search direction to steepest
                // descent (periodic restart every n iterations, or when a
                // negative beta would degrade the search direction).
                searchDirection = steepestDescent.clone();
            } else {
                // Compute new conjugate search direction:
                // d_new = h_new + beta * d_old.
                for (int i = 0; i < n; ++i) {
                    searchDirection[i] = steepestDescent[i] + beta * searchDirection[i];
                }
            }
        }
    }

    /**
     * {@inheritDoc}
     */
    @Override
    protected void parseOptimizationData(OptimizationData... optData) {
        // Allow base class to register its own data.
        super.parseOptimizationData(optData);

        // Reject unsupported settings (bounds) as early as possible.
        checkParameters();
    }

    /** Default identity preconditioner: leaves the gradient unchanged. */
    public static class IdentityPreconditioner implements Preconditioner {

        /** Empty constructor.
         * <p>
         * This constructor is not strictly necessary, but it prevents spurious
         * javadoc warnings with JDK 18 and later.
         * </p>
         * @since 3.0
         */
        public IdentityPreconditioner() { // NOPMD - unnecessary constructor added intentionally to make javadoc happy
            // nothing to do
        }

        /** {@inheritDoc} */
        @Override
        public double[] precondition(double[] variables, double[] r) {
            // Identity: return a defensive copy of the raw gradient so the
            // caller may keep and mutate it independently.
            return r.clone();
        }

    }

    // Class is not used anymore (cf. MATH-1092). However, it might
    // be interesting to create a class similar to "LineSearch", but
    // that will take advantage that the model's gradient is available.
//     /**
//      * Internal class for line search.
//      * <p>
//      * The function represented by this class is the dot product of
//      * the objective function gradient and the search direction. Its
//      * value is zero when the gradient is orthogonal to the search
//      * direction, i.e. when the objective function value is a local
//      * extremum along the search direction.
//      * </p>
//      */
//     private class LineSearchFunction implements UnivariateFunction {
//         /** Current point. */
//         private final double[] currentPoint;
//         /** Search direction. */
//         private final double[] searchDirection;

//         /**
//          * @param point Current point.
//          * @param direction Search direction.
//          */
//         public LineSearchFunction(double[] point,
//                                   double[] direction) {
//             currentPoint = point.clone();
//             searchDirection = direction.clone();
//         }

//         /** {@inheritDoc} */
//         public double value(double x) {
//             // current point in the search direction
//             final double[] shiftedPoint = currentPoint.clone();
//             for (int i = 0; i < shiftedPoint.length; ++i) {
//                 shiftedPoint[i] += x * searchDirection[i];
//             }

//             // gradient of the objective function
//             final double[] gradient = computeObjectiveGradient(shiftedPoint);

//             // dot product with the search direction
//             double dotProduct = 0;
//             for (int i = 0; i < gradient.length; ++i) {
//                 dotProduct += gradient[i] * searchDirection[i];
//             }

//             return dotProduct;
//         }
//     }

    /**
     * Checks that no unsupported optimization data was provided.
     * Bound constraints are not supported by this optimizer.
     *
     * @throws MathRuntimeException if bounds were passed to the
     * {@link #optimize(OptimizationData[]) optimize} method.
     */
    private void checkParameters() {
        if (getLowerBound() != null ||
            getUpperBound() != null) {
            throw new MathRuntimeException(LocalizedCoreFormats.CONSTRAINT);
        }
    }
}