View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  
23  package org.hipparchus.optim.nonlinear.scalar.gradient;
24  
25  import org.hipparchus.exception.LocalizedCoreFormats;
26  import org.hipparchus.exception.MathIllegalStateException;
27  import org.hipparchus.exception.MathRuntimeException;
28  import org.hipparchus.optim.ConvergenceChecker;
29  import org.hipparchus.optim.OptimizationData;
30  import org.hipparchus.optim.PointValuePair;
31  import org.hipparchus.optim.nonlinear.scalar.GoalType;
32  import org.hipparchus.optim.nonlinear.scalar.GradientMultivariateOptimizer;
33  import org.hipparchus.optim.nonlinear.scalar.LineSearch;
34  
35  
36  /**
37   * Non-linear conjugate gradient optimizer.
38   * <br>
39   * This class supports both the Fletcher-Reeves and the Polak-Ribière
40   * update formulas for the conjugate search directions.
41   * It also supports optional preconditioning.
42   * <br>
43   * Constraints are not supported: the call to
44   * {@link #optimize(OptimizationData[]) optimize} will throw
45   * {@link MathRuntimeException} if bounds are passed to it.
46   *
47   */
public class NonLinearConjugateGradientOptimizer
    extends GradientMultivariateOptimizer {
    /** Update formula for the beta parameter. */
    private final Formula updateFormula;
    /** Preconditioner. Never null when built through the convenience
     * constructors (they install an {@link IdentityPreconditioner});
     * a null passed to the full constructor would fail with a
     * NullPointerException inside {@link #doOptimize()}. */
    private final Preconditioner preconditioner;
    /** Line search algorithm used for the one-dimensional optimization
     * along each conjugate direction. */
    private final LineSearch line;

    /**
     * Available choices of update formulas for updating the parameter
     * that is used to compute the successive conjugate search directions.
     * For non-linear conjugate gradients, there are
     * two formulas:
     * <ul>
     *   <li>Fletcher-Reeves formula</li>
     *   <li>Polak-Ribière formula</li>
     * </ul>
     *
     * On the one hand, the Fletcher-Reeves formula is guaranteed to converge
     * if the start point is close enough to the optimum, whereas the
     * Polak-Ribière formula may not converge in rare cases. On the
     * other hand, the Polak-Ribière formula is often faster when it
     * does converge. Polak-Ribière is often used.
     *
     */
    public enum Formula {
        /** Fletcher-Reeves formula. */
        FLETCHER_REEVES,
        /** Polak-Ribière formula. */
        POLAK_RIBIERE
    }

    /**
     * Constructor with default tolerances for the line search (1e-8) and
     * {@link IdentityPreconditioner preconditioner}.
     *
     * @param updateFormula formula to use for updating the &beta; parameter,
     * must be one of {@link Formula#FLETCHER_REEVES} or
     * {@link Formula#POLAK_RIBIERE}.
     * @param checker Convergence checker.
     */
    public NonLinearConjugateGradientOptimizer(final Formula updateFormula,
                                               ConvergenceChecker<PointValuePair> checker) {
        this(updateFormula,
             checker,
             1e-8,
             1e-8,
             1e-8,
             new IdentityPreconditioner());
    }

    /**
     * Constructor with default {@link IdentityPreconditioner preconditioner}.
     *
     * @param updateFormula formula to use for updating the &beta; parameter,
     * must be one of {@link Formula#FLETCHER_REEVES} or
     * {@link Formula#POLAK_RIBIERE}.
     * @param checker Convergence checker.
     * @param relativeTolerance Relative threshold for line search.
     * @param absoluteTolerance Absolute threshold for line search.
     * @param initialBracketingRange Extent of the initial interval used to
     * find an interval that brackets the optimum in order to perform the
     * line search.
     *
     * @see LineSearch#LineSearch(org.hipparchus.optim.nonlinear.scalar.MultivariateOptimizer,double,double,double)
     */
    public NonLinearConjugateGradientOptimizer(final Formula updateFormula,
                                               ConvergenceChecker<PointValuePair> checker,
                                               double relativeTolerance,
                                               double absoluteTolerance,
                                               double initialBracketingRange) {
        this(updateFormula,
             checker,
             relativeTolerance,
             absoluteTolerance,
             initialBracketingRange,
             new IdentityPreconditioner());
    }

    /** Simple constructor.
     * @param updateFormula formula to use for updating the &beta; parameter,
     * must be one of {@link Formula#FLETCHER_REEVES} or
     * {@link Formula#POLAK_RIBIERE}.
     * @param checker Convergence checker.
     * @param preconditioner Preconditioner (must not be null; it is applied
     * without a null check during optimization).
     * @param relativeTolerance Relative threshold for line search.
     * @param absoluteTolerance Absolute threshold for line search.
     * @param initialBracketingRange Extent of the initial interval used to
     * find an interval that brackets the optimum in order to perform the
     * line search.
     *
     * @see LineSearch#LineSearch(org.hipparchus.optim.nonlinear.scalar.MultivariateOptimizer,double,double,double)
     */
    public NonLinearConjugateGradientOptimizer(final Formula updateFormula,
                                               ConvergenceChecker<PointValuePair> checker,
                                               double relativeTolerance,
                                               double absoluteTolerance,
                                               double initialBracketingRange,
                                               final Preconditioner preconditioner) {
        super(checker);

        this.updateFormula = updateFormula;
        this.preconditioner = preconditioner;
        line = new LineSearch(this,
                              relativeTolerance,
                              absoluteTolerance,
                              initialBracketingRange);
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public PointValuePair optimize(OptimizationData... optData)
        throws MathIllegalStateException {
        // Set up base class and perform computation.
        return super.optimize(optData);
    }

    /** {@inheritDoc} */
    @Override
    protected PointValuePair doOptimize() {
        final ConvergenceChecker<PointValuePair> checker = getConvergenceChecker();
        final double[] point = getStartPoint();
        final GoalType goal = getGoalType();
        final int n = point.length;
        // Residual r: the objective gradient, negated when minimizing so
        // that r always points towards improving objective values.
        double[] r = computeObjectiveGradient(point);
        if (goal == GoalType.MINIMIZE) {
            for (int i = 0; i < n; i++) {
                r[i] = -r[i];
            }
        }

        // Initial search direction: preconditioned residual (steepest descent).
        double[] steepestDescent = preconditioner.precondition(point, r);
        double[] searchDirection = steepestDescent.clone();

        // delta = r . (preconditioned r); serves as the denominator of the
        // next iteration's beta.
        double delta = 0;
        for (int i = 0; i < n; ++i) {
            delta += r[i] * searchDirection[i];
        }

        PointValuePair current = null;
        while (true) {
            incrementIterationCount();

            final double objective = computeObjectiveValue(point);
            PointValuePair previous = current;
            current = new PointValuePair(point, objective);
            // NOTE(review): "point" is mutated in place below; this relies on
            // PointValuePair holding its own copy of the array — confirm
            // against the PointValuePair constructor semantics.
            if (previous != null && checker.converged(getIterations(), previous, current)) {
                // We have found an optimum.
                return current;
            }

            // One-dimensional optimization along the current search direction.
            final double step = line.search(point, searchDirection).getPoint();

            // Move to the new point along the search direction.
            for (int i = 0; i < point.length; ++i) {
                point[i] += step * searchDirection[i];
            }

            // Recompute the (sign-adjusted) residual at the new point.
            r = computeObjectiveGradient(point);
            if (goal == GoalType.MINIMIZE) {
                for (int i = 0; i < n; ++i) {
                    r[i] = -r[i];
                }
            }

            // Compute beta, the conjugation factor.
            final double deltaOld = delta;
            final double[] newSteepestDescent = preconditioner.precondition(point, r);
            delta = 0;
            for (int i = 0; i < n; ++i) {
                delta += r[i] * newSteepestDescent[i];
            }

            final double beta;
            switch (updateFormula) {
            case FLETCHER_REEVES:
                // FR: beta = (r_new . s_new) / (r_old . s_old).
                beta = delta / deltaOld;
                break;
            case POLAK_RIBIERE:
                // PR: beta = r_new . (s_new - s_old) / (r_old . s_old),
                // computed as (delta - deltaMid) / deltaOld where
                // deltaMid = r_new . s_old.
                double deltaMid = 0;
                for (int i = 0; i < r.length; ++i) {
                    deltaMid += r[i] * steepestDescent[i];
                }
                beta = (delta - deltaMid) / deltaOld;
                break;
            default:
                // Should never happen: the enum has exactly two constants.
                throw MathRuntimeException.createInternalError();
            }
            steepestDescent = newSteepestDescent;

            // Compute conjugate search direction.
            if (getIterations() % n == 0 ||
                beta < 0) {
                // Break conjugation: reset search direction to steepest
                // descent (periodic restart every n iterations, or when
                // beta turns negative).
                searchDirection = steepestDescent.clone();
            } else {
                // Compute new conjugate search direction.
                for (int i = 0; i < n; ++i) {
                    searchDirection[i] = steepestDescent[i] + beta * searchDirection[i];
                }
            }
        }
    }

    /**
     * {@inheritDoc}
     */
    @Override
    protected void parseOptimizationData(OptimizationData... optData) {
        // Allow base class to register its own data.
        super.parseOptimizationData(optData);

        // Reject unsupported settings (bounds) as early as possible.
        checkParameters();
    }

    /** Default identity preconditioner. */
    public static class IdentityPreconditioner implements Preconditioner {

        /** Empty constructor.
         * <p>
         * This constructor is not strictly necessary, but it prevents spurious
         * javadoc warnings with JDK 18 and later.
         * </p>
         * @since 3.0
         */
        public IdentityPreconditioner() { // NOPMD - unnecessary constructor added intentionally to make javadoc happy
            // nothing to do
        }

        /** {@inheritDoc}
         * <p>
         * Returns a defensive copy of {@code r}, unchanged (identity).
         * </p>
         */
        @Override
        public double[] precondition(double[] variables, double[] r) {
            return r.clone();
        }

    }

    // Class is not used anymore (cf. MATH-1092). However, it might
    // be interesting to create a class similar to "LineSearch", but
    // that will take advantage that the model's gradient is available.
//     /**
//      * Internal class for line search.
//      * <p>
//      * The function represented by this class is the dot product of
//      * the objective function gradient and the search direction. Its
//      * value is zero when the gradient is orthogonal to the search
//      * direction, i.e. when the objective function value is a local
//      * extremum along the search direction.
//      * </p>
//      */
//     private class LineSearchFunction implements UnivariateFunction {
//         /** Current point. */
//         private final double[] currentPoint;
//         /** Search direction. */
//         private final double[] searchDirection;

//         /**
//          * @param point Current point.
//          * @param direction Search direction.
//          */
//         public LineSearchFunction(double[] point,
//                                   double[] direction) {
//             currentPoint = point.clone();
//             searchDirection = direction.clone();
//         }

//         /** {@inheritDoc} */
//         public double value(double x) {
//             // current point in the search direction
//             final double[] shiftedPoint = currentPoint.clone();
//             for (int i = 0; i < shiftedPoint.length; ++i) {
//                 shiftedPoint[i] += x * searchDirection[i];
//             }

//             // gradient of the objective function
//             final double[] gradient = computeObjectiveGradient(shiftedPoint);

//             // dot product with the search direction
//             double dotProduct = 0;
//             for (int i = 0; i < gradient.length; ++i) {
//                 dotProduct += gradient[i] * searchDirection[i];
//             }

//             return dotProduct;
//         }
//     }

    /**
     * Checks that no simple bounds were provided: this optimizer does not
     * support constraints.
     *
     * @throws MathRuntimeException if bounds were passed to the
     * {@link #optimize(OptimizationData[]) optimize} method.
     */
    private void checkParameters() {
        if (getLowerBound() != null ||
            getUpperBound() != null) {
            throw new MathRuntimeException(LocalizedCoreFormats.CONSTRAINT);
        }
    }
}