BaseMultiStartMultivariateOptimizer.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      https://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. /*
  18.  * This is not the original file distributed by the Apache Software Foundation
  19.  * It has been modified by the Hipparchus project
  20.  */
  21. package org.hipparchus.optim;

  22. import org.hipparchus.exception.LocalizedCoreFormats;
  23. import org.hipparchus.exception.MathIllegalArgumentException;
  24. import org.hipparchus.exception.MathIllegalStateException;
  25. import org.hipparchus.random.RandomVectorGenerator;

  26. /**
  27.  * Base class multi-start optimizer for a multivariate function.
  28.  * <br>
  29.  * This class wraps an optimizer in order to use it several times in
  30.  * turn with different starting points (trying to avoid being trapped
  31.  * in a local extremum when looking for a global one).
  32.  * <em>It is not a "user" class.</em>
  33.  *
  34.  * @param <P> Type of the point/value pair returned by the optimization
  35.  * algorithm.
  36.  *
  37.  */
  38. public abstract class BaseMultiStartMultivariateOptimizer<P>
  39.     extends BaseMultivariateOptimizer<P> {
  40.     /** Underlying classical optimizer. */
  41.     private final BaseMultivariateOptimizer<P> optimizer;
  42.     /** Number of evaluations already performed for all starts. */
  43.     private int totalEvaluations;
  44.     /** Number of starts to go. */
  45.     private int starts;
  46.     /** Random generator for multi-start. */
  47.     private RandomVectorGenerator generator;
  48.     /** Optimization data. */
  49.     private OptimizationData[] optimData;
  50.     /**
  51.      * Location in {@link #optimData} where the updated maximum
  52.      * number of evaluations will be stored.
  53.      */
  54.     private int maxEvalIndex = -1;
  55.     /**
  56.      * Location in {@link #optimData} where the updated start value
  57.      * will be stored.
  58.      */
  59.     private int initialGuessIndex = -1;

  60.     /**
  61.      * Create a multi-start optimizer from a single-start optimizer.
  62.      * <p>
  63.      * Note that if there are bounds constraints (see {@link #getLowerBound()}
  64.      * and {@link #getUpperBound()}), then a simple rejection algorithm is used
  65.      * at each restart. This implies that the random vector generator should have
  66.      * a good probability to generate vectors in the bounded domain, otherwise the
  67.      * rejection algorithm will hit the {@link #getMaxEvaluations()} count without
  68.      * generating a proper restart point. Users must be take great care of the <a
  69.      * href="http://en.wikipedia.org/wiki/Curse_of_dimensionality">curse of dimensionality</a>.
  70.      * </p>
  71.      * @param optimizer Single-start optimizer to wrap.
  72.      * @param starts Number of starts to perform. If {@code starts == 1},
  73.      * the {@link #optimize(OptimizationData[]) optimize} will return the
  74.      * same solution as the given {@code optimizer} would return.
  75.      * @param generator Random vector generator to use for restarts.
  76.      * @throws MathIllegalArgumentException if {@code starts < 1}.
  77.      */
  78.     protected BaseMultiStartMultivariateOptimizer(final BaseMultivariateOptimizer<P> optimizer, final int starts,
  79.                                                   final RandomVectorGenerator generator) {
  80.         super(optimizer.getConvergenceChecker());

  81.         if (starts < 1) {
  82.             throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL,
  83.                                                    starts, 1);
  84.         }

  85.         this.optimizer = optimizer;
  86.         this.starts = starts;
  87.         this.generator = generator;
  88.     }

  89.     /** {@inheritDoc} */
  90.     @Override
  91.     public int getEvaluations() {
  92.         return totalEvaluations;
  93.     }

  94.     /**
  95.      * Gets all the optima found during the last call to {@code optimize}.
  96.      * The optimizer stores all the optima found during a set of
  97.      * restarts. The {@code optimize} method returns the best point only.
  98.      * This method returns all the points found at the end of each starts,
  99.      * including the best one already returned by the {@code optimize} method.
  100.      * <br>
  101.      * The returned array as one element for each start as specified
  102.      * in the constructor. It is ordered with the results from the
  103.      * runs that did converge first, sorted from best to worst
  104.      * objective value (i.e in ascending order if minimizing and in
  105.      * descending order if maximizing), followed by {@code null} elements
  106.      * corresponding to the runs that did not converge. This means all
  107.      * elements will be {@code null} if the {@code optimize} method did throw
  108.      * an exception.
  109.      * This also means that if the first element is not {@code null}, it is
  110.      * the best point found across all starts.
  111.      * <br>
  112.      * The behaviour is undefined if this method is called before
  113.      * {@code optimize}; it will likely throw {@code NullPointerException}.
  114.      *
  115.      * @return an array containing the optima sorted from best to worst.
  116.      */
  117.     public abstract P[] getOptima();

  118.     /**
  119.      * {@inheritDoc}
  120.      *
  121.      * @throws MathIllegalStateException if {@code optData} does not contain an
  122.      * instance of {@link MaxEval} or {@link InitialGuess}.
  123.      */
  124.     @Override
  125.     public P optimize(OptimizationData... optData) {
  126.         // Store arguments in order to pass them to the internal optimizer.
  127.        optimData = optData.clone();
  128.         // Set up base class and perform computations.
  129.         return super.optimize(optData);
  130.     }

  131.     /** {@inheritDoc} */
  132.     @Override
  133.     protected P doOptimize() {
  134.         // Remove all instances of "MaxEval" and "InitialGuess" from the
  135.         // array that will be passed to the internal optimizer.
  136.         // The former is to enforce smaller numbers of allowed evaluations
  137.         // (according to how many have been used up already), and the latter
  138.         // to impose a different start value for each start.
  139.         for (int i = 0; i < optimData.length; i++) {
  140.             if (optimData[i] instanceof MaxEval) {
  141.                 optimData[i] = null;
  142.                 maxEvalIndex = i;
  143.             }
  144.             if (optimData[i] instanceof InitialGuess) {
  145.                 optimData[i] = null;
  146.                 initialGuessIndex = i;
  147.                 continue;
  148.             }
  149.         }
  150.         if (maxEvalIndex == -1) {
  151.             throw new MathIllegalStateException(LocalizedCoreFormats.ILLEGAL_STATE);
  152.         }
  153.         if (initialGuessIndex == -1) {
  154.             throw new MathIllegalStateException(LocalizedCoreFormats.ILLEGAL_STATE);
  155.         }

  156.         RuntimeException lastException = null;
  157.         totalEvaluations = 0;
  158.         clear();

  159.         final int maxEval = getMaxEvaluations();
  160.         final double[] min = getLowerBound();
  161.         final double[] max = getUpperBound();
  162.         final double[] startPoint = getStartPoint();

  163.         // Multi-start loop.
  164.         for (int i = 0; i < starts; i++) {
  165.             // CHECKSTYLE: stop IllegalCatch
  166.             try {
  167.                 // Decrease number of allowed evaluations.
  168.                 optimData[maxEvalIndex] = new MaxEval(maxEval - totalEvaluations);
  169.                 // New start value.
  170.                 double[] s = null;
  171.                 if (i == 0) {
  172.                     s = startPoint;
  173.                 } else {
  174.                     int attempts = 0;
  175.                     while (s == null) {
  176.                         if (attempts >= getMaxEvaluations()) {
  177.                             throw new MathIllegalStateException(LocalizedCoreFormats.MAX_COUNT_EXCEEDED,
  178.                                                                 getMaxEvaluations());
  179.                         }
  180.                         s = generator.nextVector();
  181.                         for (int k = 0; s != null && k < s.length; ++k) {
  182.                             if ((min != null && s[k] < min[k]) || (max != null && s[k] > max[k])) {
  183.                                 // reject the vector
  184.                                 s = null;
  185.                             }
  186.                         }
  187.                         ++attempts;
  188.                     }
  189.                 }
  190.                 optimData[initialGuessIndex] = new InitialGuess(s);
  191.                 // Optimize.
  192.                 final P result = optimizer.optimize(optimData);
  193.                 store(result);
  194.             } catch (RuntimeException mue) { // NOPMD - caching a RuntimeException is intentional here, it will be rethrown later
  195.                 lastException = mue;
  196.             }
  197.             // CHECKSTYLE: resume IllegalCatch

  198.             totalEvaluations += optimizer.getEvaluations();
  199.         }

  200.         final P[] optima = getOptima();
  201.         if (optima.length == 0) {
  202.             // All runs failed.
  203.             throw lastException; // Cannot be null if starts >= 1.
  204.         }

  205.         // Return the best optimum.
  206.         return optima[0];
  207.     }

  208.     /**
  209.      * Method that will be called in order to store each found optimum.
  210.      *
  211.      * @param optimum Result of an optimization run.
  212.      */
  213.     protected abstract void store(P optimum);
  214.     /**
  215.      * Method that will called in order to clear all stored optima.
  216.      */
  217.     protected abstract void clear();
  218. }