BrentOptimizer.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      https://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. /*
  18.  * This is not the original file distributed by the Apache Software Foundation
  19.  * It has been modified by the Hipparchus project
  20.  */
  21. package org.hipparchus.optim.univariate;

  22. import org.hipparchus.exception.LocalizedCoreFormats;
  23. import org.hipparchus.exception.MathIllegalArgumentException;
  24. import org.hipparchus.optim.ConvergenceChecker;
  25. import org.hipparchus.optim.nonlinear.scalar.GoalType;
  26. import org.hipparchus.util.FastMath;
  27. import org.hipparchus.util.Precision;

  28. /**
  29.  * For a function defined on some interval {@code (lo, hi)}, this class
  30.  * finds an approximation {@code x} to the point at which the function
  31.  * attains its minimum.
  32.  * It implements Richard Brent's algorithm (from his book "Algorithms for
  33.  * Minimization without Derivatives", p. 79) for finding minima of real
  34.  * univariate functions.
  35.  * <br>
  36.  * This code is an adaptation, partly based on the Python code from SciPy
  37.  * (module "optimize.py" v0.5); the original algorithm is also modified
  38.  * <ul>
  39.  *  <li>to use an initial guess provided by the user,</li>
  40.  *  <li>to ensure that the best point encountered is the one returned.</li>
  41.  * </ul>
  42.  *
  43.  */
  44. public class BrentOptimizer extends UnivariateOptimizer {
  45.     /**
  46.      * Golden section.
  47.      */
  48.     private static final double GOLDEN_SECTION = 0.5 * (3 - FastMath.sqrt(5));
  49.     /**
  50.      * Minimum relative tolerance.
  51.      */
  52.     private static final double MIN_RELATIVE_TOLERANCE = 2 * FastMath.ulp(1d);
  53.     /**
  54.      * Relative threshold.
  55.      */
  56.     private final double relativeThreshold;
  57.     /**
  58.      * Absolute threshold.
  59.      */
  60.     private final double absoluteThreshold;

  61.     /**
  62.      * The arguments are used implement the original stopping criterion
  63.      * of Brent's algorithm.
  64.      * {@code abs} and {@code rel} define a tolerance
  65.      * {@code tol = rel |x| + abs}. {@code rel} should be no smaller than
  66.      * <em>2 macheps</em> and preferably not much less than <em>sqrt(macheps)</em>,
  67.      * where <em>macheps</em> is the relative machine precision. {@code abs} must
  68.      * be positive.
  69.      *
  70.      * @param rel Relative threshold.
  71.      * @param abs Absolute threshold.
  72.      * @param checker Additional, user-defined, convergence checking
  73.      * procedure.
  74.      * @throws MathIllegalArgumentException if {@code abs <= 0}.
  75.      * @throws MathIllegalArgumentException if {@code rel < 2 * FastMath.ulp(1d)}.
  76.      */
  77.     public BrentOptimizer(double rel,
  78.                           double abs,
  79.                           ConvergenceChecker<UnivariatePointValuePair> checker) {
  80.         super(checker);

  81.         if (rel < MIN_RELATIVE_TOLERANCE) {
  82.             throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL,
  83.                                                    rel, MIN_RELATIVE_TOLERANCE);
  84.         }
  85.         if (abs <= 0) {
  86.             throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL_BOUND_EXCLUDED,
  87.                                                    abs, 0);
  88.         }

  89.         relativeThreshold = rel;
  90.         absoluteThreshold = abs;
  91.     }

  92.     /**
  93.      * The arguments are used for implementing the original stopping criterion
  94.      * of Brent's algorithm.
  95.      * {@code abs} and {@code rel} define a tolerance
  96.      * {@code tol = rel |x| + abs}. {@code rel} should be no smaller than
  97.      * <em>2 macheps</em> and preferably not much less than <em>sqrt(macheps)</em>,
  98.      * where <em>macheps</em> is the relative machine precision. {@code abs} must
  99.      * be positive.
  100.      *
  101.      * @param rel Relative threshold.
  102.      * @param abs Absolute threshold.
  103.      * @throws MathIllegalArgumentException if {@code abs <= 0}.
  104.      * @throws MathIllegalArgumentException if {@code rel < 2 * FastMath.ulp(1d)}.
  105.      */
  106.     public BrentOptimizer(double rel,
  107.                           double abs) {
  108.         this(rel, abs, null);
  109.     }

  110.     /** {@inheritDoc} */
  111.     @Override
  112.     protected UnivariatePointValuePair doOptimize() {
  113.         final boolean isMinim = getGoalType() == GoalType.MINIMIZE;
  114.         final double lo = getMin();
  115.         final double mid = getStartValue();
  116.         final double hi = getMax();

  117.         // Optional additional convergence criteria.
  118.         final ConvergenceChecker<UnivariatePointValuePair> checker
  119.             = getConvergenceChecker();

  120.         double a;
  121.         double b;
  122.         if (lo < hi) {
  123.             a = lo;
  124.             b = hi;
  125.         } else {
  126.             a = hi;
  127.             b = lo;
  128.         }

  129.         double x = mid;
  130.         double v = x;
  131.         double w = x;
  132.         double d = 0;
  133.         double e = 0;
  134.         double fx = computeObjectiveValue(x);
  135.         if (!isMinim) {
  136.             fx = -fx;
  137.         }
  138.         double fv = fx;
  139.         double fw = fx;

  140.         UnivariatePointValuePair previous = null;
  141.         UnivariatePointValuePair current
  142.             = new UnivariatePointValuePair(x, isMinim ? fx : -fx);
  143.         // Best point encountered so far (which is the initial guess).
  144.         UnivariatePointValuePair best = current;

  145.         while (true) {
  146.             final double m = 0.5 * (a + b);
  147.             final double tol1 = relativeThreshold * FastMath.abs(x) + absoluteThreshold;
  148.             final double tol2 = 2 * tol1;

  149.             // Default stopping criterion.
  150.             final boolean stop = FastMath.abs(x - m) <= tol2 - 0.5 * (b - a);
  151.             if (!stop) {
  152.                 double u;

  153.                 if (FastMath.abs(e) > tol1) { // Fit parabola.
  154.                     double r = (x - w) * (fx - fv);
  155.                     double q = (x - v) * (fx - fw);
  156.                     double p = (x - v) * q - (x - w) * r;
  157.                     q = 2 * (q - r);

  158.                     if (q > 0) {
  159.                         p = -p;
  160.                     } else {
  161.                         q = -q;
  162.                     }

  163.                     r = e;
  164.                     e = d;

  165.                     if (p > q * (a - x) &&
  166.                         p < q * (b - x) &&
  167.                         FastMath.abs(p) < FastMath.abs(0.5 * q * r)) {
  168.                         // Parabolic interpolation step.
  169.                         d = p / q;
  170.                         u = x + d;

  171.                         // f must not be evaluated too close to a or b.
  172.                         if (u - a < tol2 || b - u < tol2) {
  173.                             if (x <= m) {
  174.                                 d = tol1;
  175.                             } else {
  176.                                 d = -tol1;
  177.                             }
  178.                         }
  179.                     } else {
  180.                         // Golden section step.
  181.                         if (x < m) {
  182.                             e = b - x;
  183.                         } else {
  184.                             e = a - x;
  185.                         }
  186.                         d = GOLDEN_SECTION * e;
  187.                     }
  188.                 } else {
  189.                     // Golden section step.
  190.                     if (x < m) {
  191.                         e = b - x;
  192.                     } else {
  193.                         e = a - x;
  194.                     }
  195.                     d = GOLDEN_SECTION * e;
  196.                 }

  197.                 // Update by at least "tol1".
  198.                 if (FastMath.abs(d) < tol1) {
  199.                     if (d >= 0) {
  200.                         u = x + tol1;
  201.                     } else {
  202.                         u = x - tol1;
  203.                     }
  204.                 } else {
  205.                     u = x + d;
  206.                 }

  207.                 double fu = computeObjectiveValue(u);
  208.                 if (!isMinim) {
  209.                     fu = -fu;
  210.                 }

  211.                 // User-defined convergence checker.
  212.                 previous = current;
  213.                 current = new UnivariatePointValuePair(u, isMinim ? fu : -fu);
  214.                 best = best(best,
  215.                             best(previous,
  216.                                  current,
  217.                                  isMinim),
  218.                             isMinim);

  219.                 if (checker != null && checker.converged(getIterations(), previous, current)) {
  220.                     return best;
  221.                 }

  222.                 // Update a, b, v, w and x.
  223.                 if (fu <= fx) {
  224.                     if (u < x) {
  225.                         b = x;
  226.                     } else {
  227.                         a = x;
  228.                     }
  229.                     v = w;
  230.                     fv = fw;
  231.                     w = x;
  232.                     fw = fx;
  233.                     x = u;
  234.                     fx = fu;
  235.                 } else {
  236.                     if (u < x) {
  237.                         a = u;
  238.                     } else {
  239.                         b = u;
  240.                     }
  241.                     if (fu <= fw ||
  242.                         Precision.equals(w, x)) {
  243.                         v = w;
  244.                         fv = fw;
  245.                         w = u;
  246.                         fw = fu;
  247.                     } else if (fu <= fv ||
  248.                                Precision.equals(v, x) ||
  249.                                Precision.equals(v, w)) {
  250.                         v = u;
  251.                         fv = fu;
  252.                     }
  253.                 }
  254.             } else { // Default termination (Brent's criterion).
  255.                 return best(best,
  256.                             best(previous,
  257.                                  current,
  258.                                  isMinim),
  259.                             isMinim);
  260.             }

  261.             incrementIterationCount();
  262.         }
  263.     }

  264.     /**
  265.      * Selects the best of two points.
  266.      *
  267.      * @param a Point and value.
  268.      * @param b Point and value.
  269.      * @param isMinim {@code true} if the selected point must be the one with
  270.      * the lowest value.
  271.      * @return the best point, or {@code null} if {@code a} and {@code b} are
  272.      * both {@code null}. When {@code a} and {@code b} have the same function
  273.      * value, {@code a} is returned.
  274.      */
  275.     private UnivariatePointValuePair best(UnivariatePointValuePair a,
  276.                                           UnivariatePointValuePair b,
  277.                                           boolean isMinim) {
  278.         if (a == null) {
  279.             return b;
  280.         }
  281.         if (b == null) {
  282.             return a;
  283.         }

  284.         if (isMinim) {
  285.             return a.getValue() <= b.getValue() ? a : b;
  286.         } else {
  287.             return a.getValue() >= b.getValue() ? a : b;
  288.         }
  289.     }
  290. }