PowellOptimizer.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      https://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. /*
  18.  * This is not the original file distributed by the Apache Software Foundation
  19.  * It has been modified by the Hipparchus project
  20.  */
  21. package org.hipparchus.optim.nonlinear.scalar.noderiv;

  22. import org.hipparchus.exception.LocalizedCoreFormats;
  23. import org.hipparchus.exception.MathIllegalArgumentException;
  24. import org.hipparchus.exception.MathRuntimeException;
  25. import org.hipparchus.optim.ConvergenceChecker;
  26. import org.hipparchus.optim.OptimizationData;
  27. import org.hipparchus.optim.PointValuePair;
  28. import org.hipparchus.optim.nonlinear.scalar.GoalType;
  29. import org.hipparchus.optim.nonlinear.scalar.LineSearch;
  30. import org.hipparchus.optim.nonlinear.scalar.MultivariateOptimizer;
  31. import org.hipparchus.optim.univariate.UnivariatePointValuePair;
  32. import org.hipparchus.util.FastMath;

  33. /**
  34.  * Powell's algorithm.
  35.  * This code is translated and adapted from the Python version of this
  36.  * algorithm (as implemented in module {@code optimize.py} v0.5 of
  37.  * <em>SciPy</em>).
  38.  * <br>
  39.  * The default stopping criterion is based on the differences of the
  40.  * function value between two successive iterations. It is however possible
  41.  * to define a custom convergence checker that might terminate the algorithm
  42.  * earlier.
  43.  * <br>
  44.  * Line search is performed by the {@link LineSearch} class.
  45.  * <br>
  46.  * Constraints are not supported: the call to
  47.  * {@link #optimize(OptimizationData...)}  optimize} will throw
  48.  * {@link MathRuntimeException} if bounds are passed to it.
  49.  * In order to impose simple constraints, the objective function must be
  50.  * wrapped in an adapter like
  51.  * {@link org.hipparchus.optim.nonlinear.scalar.MultivariateFunctionMappingAdapter
  52.  * MultivariateFunctionMappingAdapter} or
  53.  * {@link org.hipparchus.optim.nonlinear.scalar.MultivariateFunctionPenaltyAdapter
  54.  * MultivariateFunctionPenaltyAdapter}.
  55.  *
  56.  */
  57. public class PowellOptimizer
  58.     extends MultivariateOptimizer {
  59.     /**
  60.      * Minimum relative tolerance.
  61.      */
  62.     private static final double MIN_RELATIVE_TOLERANCE = 2 * FastMath.ulp(1d);
  63.     /**
  64.      * Relative threshold.
  65.      */
  66.     private final double relativeThreshold;
  67.     /**
  68.      * Absolute threshold.
  69.      */
  70.     private final double absoluteThreshold;
  71.     /**
  72.      * Line search.
  73.      */
  74.     private final LineSearch line;

  75.     /**
  76.      * This constructor allows to specify a user-defined convergence checker,
  77.      * in addition to the parameters that control the default convergence
  78.      * checking procedure.
  79.      * <br>
  80.      * The internal line search tolerances are set to the square-root of their
  81.      * corresponding value in the multivariate optimizer.
  82.      *
  83.      * @param rel Relative threshold.
  84.      * @param abs Absolute threshold.
  85.      * @param checker Convergence checker.
  86.      * @throws MathIllegalArgumentException if {@code abs <= 0}.
  87.      * @throws MathIllegalArgumentException if {@code rel < 2 * FastMath.ulp(1d)}.
  88.      */
  89.     public PowellOptimizer(double rel,
  90.                            double abs,
  91.                            ConvergenceChecker<PointValuePair> checker) {
  92.         this(rel, abs, FastMath.sqrt(rel), FastMath.sqrt(abs), checker);
  93.     }

  94.     /**
  95.      * This constructor allows to specify a user-defined convergence checker,
  96.      * in addition to the parameters that control the default convergence
  97.      * checking procedure and the line search tolerances.
  98.      *
  99.      * @param rel Relative threshold for this optimizer.
  100.      * @param abs Absolute threshold for this optimizer.
  101.      * @param lineRel Relative threshold for the internal line search optimizer.
  102.      * @param lineAbs Absolute threshold for the internal line search optimizer.
  103.      * @param checker Convergence checker.
  104.      * @throws MathIllegalArgumentException if {@code abs <= 0}.
  105.      * @throws MathIllegalArgumentException if {@code rel < 2 * FastMath.ulp(1d)}.
  106.      */
  107.     public PowellOptimizer(double rel,
  108.                            double abs,
  109.                            double lineRel,
  110.                            double lineAbs,
  111.                            ConvergenceChecker<PointValuePair> checker) {
  112.         super(checker);

  113.         if (rel < MIN_RELATIVE_TOLERANCE) {
  114.             throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL,
  115.                                                    rel, MIN_RELATIVE_TOLERANCE);
  116.         }
  117.         if (abs <= 0) {
  118.             throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL_BOUND_EXCLUDED,
  119.                                                    abs, 0);
  120.         }
  121.         relativeThreshold = rel;
  122.         absoluteThreshold = abs;

  123.         // Create the line search optimizer.
  124.         line = new LineSearch(this,
  125.                               lineRel,
  126.                               lineAbs,
  127.                               1d);
  128.     }

  129.     /**
  130.      * The parameters control the default convergence checking procedure.
  131.      * <br>
  132.      * The internal line search tolerances are set to the square-root of their
  133.      * corresponding value in the multivariate optimizer.
  134.      *
  135.      * @param rel Relative threshold.
  136.      * @param abs Absolute threshold.
  137.      * @throws MathIllegalArgumentException if {@code abs <= 0}.
  138.      * @throws MathIllegalArgumentException if {@code rel < 2 * FastMath.ulp(1d)}.
  139.      */
  140.     public PowellOptimizer(double rel,
  141.                            double abs) {
  142.         this(rel, abs, null);
  143.     }

  144.     /**
  145.      * Builds an instance with the default convergence checking procedure.
  146.      *
  147.      * @param rel Relative threshold.
  148.      * @param abs Absolute threshold.
  149.      * @param lineRel Relative threshold for the internal line search optimizer.
  150.      * @param lineAbs Absolute threshold for the internal line search optimizer.
  151.      * @throws MathIllegalArgumentException if {@code abs <= 0}.
  152.      * @throws MathIllegalArgumentException if {@code rel < 2 * FastMath.ulp(1d)}.
  153.      */
  154.     public PowellOptimizer(double rel,
  155.                            double abs,
  156.                            double lineRel,
  157.                            double lineAbs) {
  158.         this(rel, abs, lineRel, lineAbs, null);
  159.     }

  160.     /** {@inheritDoc} */
  161.     @Override
  162.     protected PointValuePair doOptimize() {
  163.         checkParameters();

  164.         final GoalType goal = getGoalType();
  165.         final double[] guess = getStartPoint();
  166.         final int n = guess.length;

  167.         final double[][] direc = new double[n][n];
  168.         for (int i = 0; i < n; i++) {
  169.             direc[i][i] = 1;
  170.         }

  171.         final ConvergenceChecker<PointValuePair> checker
  172.             = getConvergenceChecker();

  173.         double[] x = guess;
  174.         double fVal = computeObjectiveValue(x);
  175.         double[] x1 = x.clone();
  176.         while (true) {
  177.             incrementIterationCount();

  178.             double fX = fVal;
  179.             double delta = 0;
  180.             int bigInd = 0;

  181.             for (int i = 0; i < n; i++) {
  182.                 final double[] d = direc[i].clone();

  183.                 final double fX2 = fVal;

  184.                 final UnivariatePointValuePair optimum = line.search(x, d);
  185.                 fVal = optimum.getValue();
  186.                 final double alphaMin = optimum.getPoint();
  187.                 final double[][] result = newPointAndDirection(x, d, alphaMin);
  188.                 x = result[0];

  189.                 if ((fX2 - fVal) > delta) {
  190.                     delta = fX2 - fVal;
  191.                     bigInd = i;
  192.                 }
  193.             }

  194.             // Default convergence check.
  195.             boolean stop = 2 * (fX - fVal) <=
  196.                 (relativeThreshold * (FastMath.abs(fX) + FastMath.abs(fVal)) +
  197.                  absoluteThreshold);

  198.             final PointValuePair previous = new PointValuePair(x1, fX);
  199.             final PointValuePair current = new PointValuePair(x, fVal);
  200.             if (!stop && checker != null) { // User-defined stopping criteria.
  201.                 stop = checker.converged(getIterations(), previous, current);
  202.             }
  203.             if (stop) {
  204.                 if (goal == GoalType.MINIMIZE) {
  205.                     return (fVal < fX) ? current : previous;
  206.                 } else {
  207.                     return (fVal > fX) ? current : previous;
  208.                 }
  209.             }

  210.             final double[] d = new double[n];
  211.             final double[] x2 = new double[n];
  212.             for (int i = 0; i < n; i++) {
  213.                 d[i] = x[i] - x1[i];
  214.                 x2[i] = 2 * x[i] - x1[i];
  215.             }

  216.             x1 = x.clone();
  217.             final double fX2 = computeObjectiveValue(x2);

  218.             if (fX > fX2) {
  219.                 double t = 2 * (fX + fX2 - 2 * fVal);
  220.                 double temp = fX - fVal - delta;
  221.                 t *= temp * temp;
  222.                 temp = fX - fX2;
  223.                 t -= delta * temp * temp;

  224.                 if (t < 0.0) {
  225.                     final UnivariatePointValuePair optimum = line.search(x, d);
  226.                     fVal = optimum.getValue();
  227.                     final double alphaMin = optimum.getPoint();
  228.                     final double[][] result = newPointAndDirection(x, d, alphaMin);
  229.                     x = result[0];

  230.                     final int lastInd = n - 1;
  231.                     direc[bigInd] = direc[lastInd];
  232.                     direc[lastInd] = result[1];
  233.                 }
  234.             }
  235.         }
  236.     }

  237.     /**
  238.      * Compute a new point (in the original space) and a new direction
  239.      * vector, resulting from the line search.
  240.      *
  241.      * @param p Point used in the line search.
  242.      * @param d Direction used in the line search.
  243.      * @param optimum Optimum found by the line search.
  244.      * @return a 2-element array containing the new point (at index 0) and
  245.      * the new direction (at index 1).
  246.      */
  247.     private double[][] newPointAndDirection(double[] p,
  248.                                             double[] d,
  249.                                             double optimum) {
  250.         final int n = p.length;
  251.         final double[] nP = new double[n];
  252.         final double[] nD = new double[n];
  253.         for (int i = 0; i < n; i++) {
  254.             nD[i] = d[i] * optimum;
  255.             nP[i] = p[i] + nD[i];
  256.         }

  257.         final double[][] result = new double[2][];
  258.         result[0] = nP;
  259.         result[1] = nD;

  260.         return result;
  261.     }

  262.     /**
  263.      * @throws MathRuntimeException if bounds were passed to the
  264.      * {@link #optimize(OptimizationData...)}  optimize} method.
  265.      */
  266.     private void checkParameters() {
  267.         if (getLowerBound() != null ||
  268.             getUpperBound() != null) {
  269.             throw new MathRuntimeException(LocalizedCoreFormats.CONSTRAINT);
  270.         }
  271.     }
  272. }