NonLinearConjugateGradientOptimizer.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* This is not the original file distributed by the Apache Software Foundation
* It has been modified by the Hipparchus project
*/
package org.hipparchus.optim.nonlinear.scalar.gradient;
import org.hipparchus.exception.LocalizedCoreFormats;
import org.hipparchus.exception.MathIllegalStateException;
import org.hipparchus.exception.MathRuntimeException;
import org.hipparchus.optim.ConvergenceChecker;
import org.hipparchus.optim.OptimizationData;
import org.hipparchus.optim.PointValuePair;
import org.hipparchus.optim.nonlinear.scalar.GoalType;
import org.hipparchus.optim.nonlinear.scalar.GradientMultivariateOptimizer;
import org.hipparchus.optim.nonlinear.scalar.LineSearch;
/**
* Non-linear conjugate gradient optimizer.
* <br>
* This class supports both the Fletcher-Reeves and the Polak-Ribière
* update formulas for the conjugate search directions.
* It also supports optional preconditioning.
* <br>
* Constraints are not supported: the call to
* {@link #optimize(OptimizationData[]) optimize} will throw
* {@link MathRuntimeException} if bounds are passed to it.
*
*/
public class NonLinearConjugateGradientOptimizer
extends GradientMultivariateOptimizer {
/** Update formula for the beta parameter; selects the branch taken when the conjugate direction is updated. */
private final Formula updateFormula;
/** Preconditioner, stored exactly as it was passed to the constructor (no validation is performed there). */
private final Preconditioner preconditioner;
/** Line search algorithm, built in the constructor from the tolerances and the initial bracketing range. */
private final LineSearch line;
/**
* Available choices of update formulas for updating the parameter
* that is used to compute the successive conjugate search directions.
* For non-linear conjugate gradients, there are
* two formulas:
* <ul>
* <li>Fletcher-Reeves formula</li>
* <li>Polak-Ribière formula</li>
* </ul>
*
* On the one hand, the Fletcher-Reeves formula is guaranteed to converge
* if the start point is close enough to the optimum, whereas the
* Polak-Ribière formula may not converge in rare cases. On the
* other hand, the Polak-Ribière formula is often faster when it
* does converge. Polak-Ribière is often used.
*
*/
public enum Formula {
/** Fletcher-Reeves formula: β = δ<sub>new</sub> / δ<sub>old</sub>. */
FLETCHER_REEVES,
/** Polak-Ribière formula: β = (δ<sub>new</sub> - δ<sub>mid</sub>) / δ<sub>old</sub>. */
POLAK_RIBIERE
}
/**
 * Builds an optimizer that uses the default
 * {@link IdentityPreconditioner preconditioner} and {@code 1e-8} for each
 * of the three line search settings (relative tolerance, absolute
 * tolerance and initial bracketing range).
 *
 * @param updateFormula formula to use for updating the β parameter,
 * must be one of {@link Formula#FLETCHER_REEVES} or
 * {@link Formula#POLAK_RIBIERE}.
 * @param checker Convergence checker.
 */
public NonLinearConjugateGradientOptimizer(final Formula updateFormula,
                                           ConvergenceChecker<PointValuePair> checker) {
    // The five-argument overload supplies the identity preconditioner.
    this(updateFormula, checker, 1e-8, 1e-8, 1e-8);
}
/**
 * Builds an optimizer with caller-supplied line search settings and the
 * default {@link IdentityPreconditioner preconditioner}.
 *
 * @param updateFormula formula to use for updating the β parameter,
 * must be one of {@link Formula#FLETCHER_REEVES} or
 * {@link Formula#POLAK_RIBIERE}.
 * @param checker Convergence checker.
 * @param relativeTolerance Relative threshold for line search.
 * @param absoluteTolerance Absolute threshold for line search.
 * @param initialBracketingRange Extent of the initial interval used to
 * find an interval that brackets the optimum in order to perform the
 * line search.
 *
 * @see LineSearch#LineSearch(org.hipparchus.optim.nonlinear.scalar.MultivariateOptimizer,double,double,double)
 */
public NonLinearConjugateGradientOptimizer(final Formula updateFormula,
                                           ConvergenceChecker<PointValuePair> checker,
                                           double relativeTolerance,
                                           double absoluteTolerance,
                                           double initialBracketingRange) {
    // Forward everything unchanged; only the preconditioner is defaulted here.
    this(updateFormula, checker,
         relativeTolerance, absoluteTolerance, initialBracketingRange,
         new IdentityPreconditioner());
}
/**
 * Builds an optimizer from explicit settings: every other constructor of
 * this class ends up here.
 *
 * @param updateFormula formula to use for updating the β parameter,
 * must be one of {@link Formula#FLETCHER_REEVES} or
 * {@link Formula#POLAK_RIBIERE}.
 * @param checker Convergence checker.
 * @param relativeTolerance Relative threshold for line search.
 * @param absoluteTolerance Absolute threshold for line search.
 * @param initialBracketingRange Extent of the initial interval used to
 * find an interval that brackets the optimum in order to perform the
 * line search.
 * @param preconditioner Preconditioner.
 *
 * @see LineSearch#LineSearch(org.hipparchus.optim.nonlinear.scalar.MultivariateOptimizer,double,double,double)
 */
public NonLinearConjugateGradientOptimizer(final Formula updateFormula,
                                           ConvergenceChecker<PointValuePair> checker,
                                           double relativeTolerance,
                                           double absoluteTolerance,
                                           double initialBracketingRange,
                                           final Preconditioner preconditioner) {
    super(checker);

    // Both settings are stored as given.
    this.preconditioner = preconditioner;
    this.updateFormula = updateFormula;

    // The line search is bound to this optimizer instance.
    this.line = new LineSearch(this, relativeTolerance, absoluteTolerance, initialBracketingRange);
}
/**
 * {@inheritDoc}
 */
@Override
public PointValuePair optimize(OptimizationData... optData)
    throws MathIllegalStateException {
    // Nothing is added here: the parent implementation handles the whole call.
    final PointValuePair optimum = super.optimize(optData);
    return optimum;
}
/**
 * {@inheritDoc}
 * <p>
 * Each pass of the loop evaluates the objective at the current point,
 * asks the convergence checker to compare it with the previous pass and,
 * if it has not converged, performs a line search along the current
 * search direction and then derives the next search direction (either a
 * conjugate direction or a reset to the preconditioned gradient).
 * </p>
 * <p>
 * A {@code null} preconditioner is treated as the identity, and a β value
 * that is NaN or infinite (e.g. when the previous δ is zero) resets the
 * search direction instead of propagating NaN into it.
 * </p>
 *
 * @return the point/value pair on which the convergence checker reported
 * convergence.
 */
@Override
protected PointValuePair doOptimize() {
    final ConvergenceChecker<PointValuePair> checker = getConvergenceChecker();
    final double[] point = getStartPoint();
    final GoalType goal = getGoalType();
    final int n = point.length;

    double[] r = computeSignedGradient(point, goal);

    // Initial search direction.
    double[] steepestDescent = applyPreconditioner(point, r);
    double[] searchDirection = steepestDescent.clone();

    double delta = 0;
    for (int i = 0; i < n; ++i) {
        delta += r[i] * searchDirection[i];
    }

    PointValuePair current = null;
    while (true) {
        incrementIterationCount();

        final double objective = computeObjectiveValue(point);
        final PointValuePair previous = current;
        current = new PointValuePair(point, objective);
        if (previous != null && checker.converged(getIterations(), previous, current)) {
            // We have found an optimum.
            return current;
        }

        // Move the point (in place) by the step found by the line search.
        final double step = line.search(point, searchDirection).getPoint();
        for (int i = 0; i < point.length; ++i) {
            point[i] += step * searchDirection[i];
        }

        r = computeSignedGradient(point, goal);

        // Compute beta.
        final double deltaOld = delta;
        final double[] newSteepestDescent = applyPreconditioner(point, r);
        delta = 0;
        for (int i = 0; i < n; ++i) {
            delta += r[i] * newSteepestDescent[i];
        }

        final double beta;
        switch (updateFormula) {
            case FLETCHER_REEVES:
                beta = delta / deltaOld;
                break;
            case POLAK_RIBIERE:
                // Uses the preconditioned direction of the previous pass.
                double deltaMid = 0;
                for (int i = 0; i < r.length; ++i) {
                    deltaMid += r[i] * steepestDescent[i];
                }
                beta = (delta - deltaMid) / deltaOld;
                break;
            default:
                // Should never happen.
                throw MathRuntimeException.createInternalError();
        }
        steepestDescent = newSteepestDescent;

        // A division by a zero deltaOld yields NaN or an infinity; neither
        // is caught by "beta < 0", so it is tested for explicitly.
        final boolean betaUnusable = beta < 0 || Double.isNaN(beta) || Double.isInfinite(beta);

        // Compute conjugate search direction.
        if (getIterations() % n == 0 || betaUnusable) {
            // Break conjugation: reset search direction.
            searchDirection = steepestDescent.clone();
        } else {
            // Compute new conjugate search direction.
            for (int i = 0; i < n; ++i) {
                searchDirection[i] = steepestDescent[i] + beta * searchDirection[i];
            }
        }
    }
}

/**
 * Evaluates the objective gradient and orients it for the requested goal.
 *
 * @param point Point at which the gradient is evaluated.
 * @param goal Goal type; for {@link GoalType#MINIMIZE} the gradient is
 * negated in place, otherwise it is returned untouched.
 * @return the array produced by the gradient evaluation, possibly negated.
 */
private double[] computeSignedGradient(final double[] point, final GoalType goal) {
    final double[] gradient = computeObjectiveGradient(point);
    if (goal == GoalType.MINIMIZE) {
        for (int i = 0; i < point.length; ++i) {
            gradient[i] = -gradient[i];
        }
    }
    return gradient;
}

/**
 * Applies the configured preconditioner, falling back to a plain copy of
 * {@code r} when no preconditioner was configured.
 *
 * @param point Current point.
 * @param r Signed gradient at {@code point}.
 * @return the preconditioned direction, as a new array in the fallback case.
 */
private double[] applyPreconditioner(final double[] point, final double[] r) {
    return preconditioner == null ? r.clone() : preconditioner.precondition(point, r);
}
/**
 * {@inheritDoc}
 */
@Override
protected void parseOptimizationData(OptimizationData... optData) {
    // The parent class goes first so that it can register its own data.
    super.parseOptimizationData(optData);

    // Bounds are rejected by this optimizer.
    checkParameters();
}
/** Default preconditioner: leaves the values untouched and hands back a copy of them. */
public static class IdentityPreconditioner implements Preconditioner {

    /** Empty constructor.
     * <p>
     * This constructor is not strictly necessary, but it prevents spurious
     * javadoc warnings with JDK 18 and later.
     * </p>
     * @since 3.0
     */
    public IdentityPreconditioner() { // NOPMD - unnecessary constructor added intentionally to make javadoc happy
        // nothing to do
    }

    /** {@inheritDoc} */
    @Override
    public double[] precondition(double[] variables, double[] r) {
        // The caller receives a fresh array; "variables" is not used.
        final double[] copy = new double[r.length];
        System.arraycopy(r, 0, copy, 0, r.length);
        return copy;
    }

}
// Class is not used anymore (cf. MATH-1092). However, it might
// be interesting to create a class similar to "LineSearch" that
// takes advantage of the availability of the model's gradient.
// /**
// * Internal class for line search.
// * <p>
// * The function represented by this class is the dot product of
// * the objective function gradient and the search direction. Its
// * value is zero when the gradient is orthogonal to the search
// * direction, i.e. when the objective function value is a local
// * extremum along the search direction.
// * </p>
// */
// private class LineSearchFunction implements UnivariateFunction {
// /** Current point. */
// private final double[] currentPoint;
// /** Search direction. */
// private final double[] searchDirection;
// /**
// * @param point Current point.
// * @param direction Search direction.
// */
// public LineSearchFunction(double[] point,
// double[] direction) {
// currentPoint = point.clone();
// searchDirection = direction.clone();
// }
// /** {@inheritDoc} */
// public double value(double x) {
// // current point in the search direction
// final double[] shiftedPoint = currentPoint.clone();
// for (int i = 0; i < shiftedPoint.length; ++i) {
// shiftedPoint[i] += x * searchDirection[i];
// }
// // gradient of the objective function
// final double[] gradient = computeObjectiveGradient(shiftedPoint);
// // dot product with the search direction
// double dotProduct = 0;
// for (int i = 0; i < gradient.length; ++i) {
// dotProduct += gradient[i] * searchDirection[i];
// }
// return dotProduct;
// }
// }
/**
 * Rejects any bound constraint, since this optimizer does not handle them.
 *
 * @throws MathRuntimeException if bounds were passed to the
 * {@link #optimize(OptimizationData[]) optimize} method.
 */
private void checkParameters() {
    // The lower bound is looked at first; the upper one only if needed.
    final boolean unconstrained = getLowerBound() == null && getUpperBound() == null;
    if (unconstrained) {
        return;
    }
    throw new MathRuntimeException(LocalizedCoreFormats.CONSTRAINT);
}
}