View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  package org.hipparchus.optim.univariate;
23  
24  import org.hipparchus.exception.LocalizedCoreFormats;
25  import org.hipparchus.exception.MathIllegalArgumentException;
26  import org.hipparchus.optim.ConvergenceChecker;
27  import org.hipparchus.optim.nonlinear.scalar.GoalType;
28  import org.hipparchus.util.FastMath;
29  import org.hipparchus.util.Precision;
30  
31  /**
32   * For a function defined on some interval {@code (lo, hi)}, this class
33   * finds an approximation {@code x} to the point at which the function
34   * attains its minimum.
35   * It implements Richard Brent's algorithm (from his book "Algorithms for
36   * Minimization without Derivatives", p. 79) for finding minima of real
37   * univariate functions.
38   * <br>
39   * This code is an adaptation, partly based on the Python code from SciPy
40   * (module "optimize.py" v0.5); the original algorithm is also modified
41   * <ul>
42   *  <li>to use an initial guess provided by the user,</li>
43   *  <li>to ensure that the best point encountered is the one returned.</li>
44   * </ul>
45   *
46   */
47  public class BrentOptimizer extends UnivariateOptimizer {
48      /**
49       * Golden section.
50       */
51      private static final double GOLDEN_SECTION = 0.5 * (3 - FastMath.sqrt(5));
52      /**
53       * Minimum relative tolerance.
54       */
55      private static final double MIN_RELATIVE_TOLERANCE = 2 * FastMath.ulp(1d);
56      /**
57       * Relative threshold.
58       */
59      private final double relativeThreshold;
60      /**
61       * Absolute threshold.
62       */
63      private final double absoluteThreshold;
64  
65      /**
66       * The arguments are used implement the original stopping criterion
67       * of Brent's algorithm.
68       * {@code abs} and {@code rel} define a tolerance
69       * {@code tol = rel |x| + abs}. {@code rel} should be no smaller than
70       * <em>2 macheps</em> and preferably not much less than <em>sqrt(macheps)</em>,
71       * where <em>macheps</em> is the relative machine precision. {@code abs} must
72       * be positive.
73       *
74       * @param rel Relative threshold.
75       * @param abs Absolute threshold.
76       * @param checker Additional, user-defined, convergence checking
77       * procedure.
78       * @throws MathIllegalArgumentException if {@code abs <= 0}.
79       * @throws MathIllegalArgumentException if {@code rel < 2 * FastMath.ulp(1d)}.
80       */
81      public BrentOptimizer(double rel,
82                            double abs,
83                            ConvergenceChecker<UnivariatePointValuePair> checker) {
84          super(checker);
85  
86          if (rel < MIN_RELATIVE_TOLERANCE) {
87              throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL,
88                                                     rel, MIN_RELATIVE_TOLERANCE);
89          }
90          if (abs <= 0) {
91              throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL_BOUND_EXCLUDED,
92                                                     abs, 0);
93          }
94  
95          relativeThreshold = rel;
96          absoluteThreshold = abs;
97      }
98  
99      /**
100      * The arguments are used for implementing the original stopping criterion
101      * of Brent's algorithm.
102      * {@code abs} and {@code rel} define a tolerance
103      * {@code tol = rel |x| + abs}. {@code rel} should be no smaller than
104      * <em>2 macheps</em> and preferably not much less than <em>sqrt(macheps)</em>,
105      * where <em>macheps</em> is the relative machine precision. {@code abs} must
106      * be positive.
107      *
108      * @param rel Relative threshold.
109      * @param abs Absolute threshold.
110      * @throws MathIllegalArgumentException if {@code abs <= 0}.
111      * @throws MathIllegalArgumentException if {@code rel < 2 * FastMath.ulp(1d)}.
112      */
113     public BrentOptimizer(double rel,
114                           double abs) {
115         this(rel, abs, null);
116     }
117 
118     /** {@inheritDoc} */
119     @Override
120     protected UnivariatePointValuePair doOptimize() {
121         final boolean isMinim = getGoalType() == GoalType.MINIMIZE;
122         final double lo = getMin();
123         final double mid = getStartValue();
124         final double hi = getMax();
125 
126         // Optional additional convergence criteria.
127         final ConvergenceChecker<UnivariatePointValuePair> checker
128             = getConvergenceChecker();
129 
130         double a;
131         double b;
132         if (lo < hi) {
133             a = lo;
134             b = hi;
135         } else {
136             a = hi;
137             b = lo;
138         }
139 
140         double x = mid;
141         double v = x;
142         double w = x;
143         double d = 0;
144         double e = 0;
145         double fx = computeObjectiveValue(x);
146         if (!isMinim) {
147             fx = -fx;
148         }
149         double fv = fx;
150         double fw = fx;
151 
152         UnivariatePointValuePair previous = null;
153         UnivariatePointValuePair current
154             = new UnivariatePointValuePair(x, isMinim ? fx : -fx);
155         // Best point encountered so far (which is the initial guess).
156         UnivariatePointValuePair best = current;
157 
158         while (true) {
159             final double m = 0.5 * (a + b);
160             final double tol1 = relativeThreshold * FastMath.abs(x) + absoluteThreshold;
161             final double tol2 = 2 * tol1;
162 
163             // Default stopping criterion.
164             final boolean stop = FastMath.abs(x - m) <= tol2 - 0.5 * (b - a);
165             if (!stop) {
166                 double u;
167 
168                 if (FastMath.abs(e) > tol1) { // Fit parabola.
169                     double r = (x - w) * (fx - fv);
170                     double q = (x - v) * (fx - fw);
171                     double p = (x - v) * q - (x - w) * r;
172                     q = 2 * (q - r);
173 
174                     if (q > 0) {
175                         p = -p;
176                     } else {
177                         q = -q;
178                     }
179 
180                     r = e;
181                     e = d;
182 
183                     if (p > q * (a - x) &&
184                         p < q * (b - x) &&
185                         FastMath.abs(p) < FastMath.abs(0.5 * q * r)) {
186                         // Parabolic interpolation step.
187                         d = p / q;
188                         u = x + d;
189 
190                         // f must not be evaluated too close to a or b.
191                         if (u - a < tol2 || b - u < tol2) {
192                             if (x <= m) {
193                                 d = tol1;
194                             } else {
195                                 d = -tol1;
196                             }
197                         }
198                     } else {
199                         // Golden section step.
200                         if (x < m) {
201                             e = b - x;
202                         } else {
203                             e = a - x;
204                         }
205                         d = GOLDEN_SECTION * e;
206                     }
207                 } else {
208                     // Golden section step.
209                     if (x < m) {
210                         e = b - x;
211                     } else {
212                         e = a - x;
213                     }
214                     d = GOLDEN_SECTION * e;
215                 }
216 
217                 // Update by at least "tol1".
218                 if (FastMath.abs(d) < tol1) {
219                     if (d >= 0) {
220                         u = x + tol1;
221                     } else {
222                         u = x - tol1;
223                     }
224                 } else {
225                     u = x + d;
226                 }
227 
228                 double fu = computeObjectiveValue(u);
229                 if (!isMinim) {
230                     fu = -fu;
231                 }
232 
233                 // User-defined convergence checker.
234                 previous = current;
235                 current = new UnivariatePointValuePair(u, isMinim ? fu : -fu);
236                 best = best(best,
237                             best(previous,
238                                  current,
239                                  isMinim),
240                             isMinim);
241 
242                 if (checker != null && checker.converged(getIterations(), previous, current)) {
243                     return best;
244                 }
245 
246                 // Update a, b, v, w and x.
247                 if (fu <= fx) {
248                     if (u < x) {
249                         b = x;
250                     } else {
251                         a = x;
252                     }
253                     v = w;
254                     fv = fw;
255                     w = x;
256                     fw = fx;
257                     x = u;
258                     fx = fu;
259                 } else {
260                     if (u < x) {
261                         a = u;
262                     } else {
263                         b = u;
264                     }
265                     if (fu <= fw ||
266                         Precision.equals(w, x)) {
267                         v = w;
268                         fv = fw;
269                         w = u;
270                         fw = fu;
271                     } else if (fu <= fv ||
272                                Precision.equals(v, x) ||
273                                Precision.equals(v, w)) {
274                         v = u;
275                         fv = fu;
276                     }
277                 }
278             } else { // Default termination (Brent's criterion).
279                 return best(best,
280                             best(previous,
281                                  current,
282                                  isMinim),
283                             isMinim);
284             }
285 
286             incrementIterationCount();
287         }
288     }
289 
290     /**
291      * Selects the best of two points.
292      *
293      * @param a Point and value.
294      * @param b Point and value.
295      * @param isMinim {@code true} if the selected point must be the one with
296      * the lowest value.
297      * @return the best point, or {@code null} if {@code a} and {@code b} are
298      * both {@code null}. When {@code a} and {@code b} have the same function
299      * value, {@code a} is returned.
300      */
301     private UnivariatePointValuePair best(UnivariatePointValuePair a,
302                                           UnivariatePointValuePair b,
303                                           boolean isMinim) {
304         if (a == null) {
305             return b;
306         }
307         if (b == null) {
308             return a;
309         }
310 
311         if (isMinim) {
312             return a.getValue() <= b.getValue() ? a : b;
313         } else {
314             return a.getValue() >= b.getValue() ? a : b;
315         }
316     }
317 }