NelderMeadSimplex.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      https://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. /*
  18.  * This is not the original file distributed by the Apache Software Foundation
  19.  * It has been modified by the Hipparchus project
  20.  */
  21. package org.hipparchus.optim.nonlinear.scalar.noderiv;

  22. import java.util.Comparator;

  23. import org.hipparchus.analysis.MultivariateFunction;
  24. import org.hipparchus.optim.PointValuePair;

  25. /**
  26.  * This class implements the Nelder-Mead simplex algorithm.
  27.  *
  28.  */
  29. public class NelderMeadSimplex extends AbstractSimplex {
  30.     /** Default value for {@link #rho}: {@value}. */
  31.     private static final double DEFAULT_RHO = 1;
  32.     /** Default value for {@link #khi}: {@value}. */
  33.     private static final double DEFAULT_KHI = 2;
  34.     /** Default value for {@link #gamma}: {@value}. */
  35.     private static final double DEFAULT_GAMMA = 0.5;
  36.     /** Default value for {@link #sigma}: {@value}. */
  37.     private static final double DEFAULT_SIGMA = 0.5;
  38.     /** Reflection coefficient. */
  39.     private final double rho;
  40.     /** Expansion coefficient. */
  41.     private final double khi;
  42.     /** Contraction coefficient. */
  43.     private final double gamma;
  44.     /** Shrinkage coefficient. */
  45.     private final double sigma;

  46.     /**
  47.      * Build a Nelder-Mead simplex with default coefficients.
  48.      * The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
  49.      * for both gamma and sigma.
  50.      *
  51.      * @param n Dimension of the simplex.
  52.      */
  53.     public NelderMeadSimplex(final int n) {
  54.         this(n, 1d);
  55.     }

  56.     /**
  57.      * Build a Nelder-Mead simplex with default coefficients.
  58.      * The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
  59.      * for both gamma and sigma.
  60.      *
  61.      * @param n Dimension of the simplex.
  62.      * @param sideLength Length of the sides of the default (hypercube)
  63.      * simplex. See {@link AbstractSimplex#AbstractSimplex(int,double)}.
  64.      */
  65.     public NelderMeadSimplex(final int n, double sideLength) {
  66.         this(n, sideLength,
  67.              DEFAULT_RHO, DEFAULT_KHI, DEFAULT_GAMMA, DEFAULT_SIGMA);
  68.     }

  69.     /**
  70.      * Build a Nelder-Mead simplex with specified coefficients.
  71.      *
  72.      * @param n Dimension of the simplex. See
  73.      * {@link AbstractSimplex#AbstractSimplex(int,double)}.
  74.      * @param sideLength Length of the sides of the default (hypercube)
  75.      * simplex. See {@link AbstractSimplex#AbstractSimplex(int,double)}.
  76.      * @param rho Reflection coefficient.
  77.      * @param khi Expansion coefficient.
  78.      * @param gamma Contraction coefficient.
  79.      * @param sigma Shrinkage coefficient.
  80.      */
  81.     public NelderMeadSimplex(final int n, double sideLength,
  82.                              final double rho, final double khi,
  83.                              final double gamma, final double sigma) {
  84.         super(n, sideLength);

  85.         this.rho = rho;
  86.         this.khi = khi;
  87.         this.gamma = gamma;
  88.         this.sigma = sigma;
  89.     }

  90.     /**
  91.      * Build a Nelder-Mead simplex with specified coefficients.
  92.      *
  93.      * @param n Dimension of the simplex. See
  94.      * {@link AbstractSimplex#AbstractSimplex(int)}.
  95.      * @param rho Reflection coefficient.
  96.      * @param khi Expansion coefficient.
  97.      * @param gamma Contraction coefficient.
  98.      * @param sigma Shrinkage coefficient.
  99.      */
  100.     public NelderMeadSimplex(final int n,
  101.                              final double rho, final double khi,
  102.                              final double gamma, final double sigma) {
  103.         this(n, 1d, rho, khi, gamma, sigma);
  104.     }

  105.     /**
  106.      * Build a Nelder-Mead simplex with default coefficients.
  107.      * The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
  108.      * for both gamma and sigma.
  109.      *
  110.      * @param steps Steps along the canonical axes representing box edges.
  111.      * They may be negative but not zero. See
  112.      */
  113.     public NelderMeadSimplex(final double[] steps) {
  114.         this(steps, DEFAULT_RHO, DEFAULT_KHI, DEFAULT_GAMMA, DEFAULT_SIGMA);
  115.     }

  116.     /**
  117.      * Build a Nelder-Mead simplex with specified coefficients.
  118.      *
  119.      * @param steps Steps along the canonical axes representing box edges.
  120.      * They may be negative but not zero. See
  121.      * {@link AbstractSimplex#AbstractSimplex(double[])}.
  122.      * @param rho Reflection coefficient.
  123.      * @param khi Expansion coefficient.
  124.      * @param gamma Contraction coefficient.
  125.      * @param sigma Shrinkage coefficient.
  126.      * @throws IllegalArgumentException if one of the steps is zero.
  127.      */
  128.     public NelderMeadSimplex(final double[] steps,
  129.                              final double rho, final double khi,
  130.                              final double gamma, final double sigma) {
  131.         super(steps);

  132.         this.rho = rho;
  133.         this.khi = khi;
  134.         this.gamma = gamma;
  135.         this.sigma = sigma;
  136.     }

  137.     /**
  138.      * Build a Nelder-Mead simplex with default coefficients.
  139.      * The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
  140.      * for both gamma and sigma.
  141.      *
  142.      * @param referenceSimplex Reference simplex. See
  143.      * {@link AbstractSimplex#AbstractSimplex(double[][])}.
  144.      */
  145.     public NelderMeadSimplex(final double[][] referenceSimplex) {
  146.         this(referenceSimplex, DEFAULT_RHO, DEFAULT_KHI, DEFAULT_GAMMA, DEFAULT_SIGMA);
  147.     }

  148.     /**
  149.      * Build a Nelder-Mead simplex with specified coefficients.
  150.      *
  151.      * @param referenceSimplex Reference simplex. See
  152.      * {@link AbstractSimplex#AbstractSimplex(double[][])}.
  153.      * @param rho Reflection coefficient.
  154.      * @param khi Expansion coefficient.
  155.      * @param gamma Contraction coefficient.
  156.      * @param sigma Shrinkage coefficient.
  157.      * @throws org.hipparchus.exception.MathIllegalArgumentException
  158.      * if the reference simplex does not contain at least one point.
  159.      * @throws org.hipparchus.exception.MathIllegalArgumentException
  160.      * if there is a dimension mismatch in the reference simplex.
  161.      */
  162.     public NelderMeadSimplex(final double[][] referenceSimplex,
  163.                              final double rho, final double khi,
  164.                              final double gamma, final double sigma) {
  165.         super(referenceSimplex);

  166.         this.rho = rho;
  167.         this.khi = khi;
  168.         this.gamma = gamma;
  169.         this.sigma = sigma;
  170.     }

  171.     /** {@inheritDoc} */
  172.     @Override
  173.     public void iterate(final MultivariateFunction evaluationFunction,
  174.                         final Comparator<PointValuePair> comparator) {
  175.         // The simplex has n + 1 points if dimension is n.
  176.         final int n = getDimension();

  177.         // Interesting values.
  178.         final PointValuePair best = getPoint(0);
  179.         final PointValuePair secondBest = getPoint(n - 1);
  180.         final PointValuePair worst = getPoint(n);
  181.         final double[] xWorst = worst.getPointRef();

  182.         // Compute the centroid of the best vertices (dismissing the worst
  183.         // point at index n).
  184.         final double[] centroid = new double[n];
  185.         for (int i = 0; i < n; i++) {
  186.             final double[] x = getPoint(i).getPointRef();
  187.             for (int j = 0; j < n; j++) {
  188.                 centroid[j] += x[j];
  189.             }
  190.         }
  191.         final double scaling = 1.0 / n;
  192.         for (int j = 0; j < n; j++) {
  193.             centroid[j] *= scaling;
  194.         }

  195.         // compute the reflection point
  196.         final double[] xR = new double[n];
  197.         for (int j = 0; j < n; j++) {
  198.             xR[j] = centroid[j] + rho * (centroid[j] - xWorst[j]);
  199.         }
  200.         final PointValuePair reflected
  201.             = new PointValuePair(xR, evaluationFunction.value(xR), false);

  202.         if (comparator.compare(best, reflected) <= 0 &&
  203.             comparator.compare(reflected, secondBest) < 0) {
  204.             // Accept the reflected point.
  205.             replaceWorstPoint(reflected, comparator);
  206.         } else if (comparator.compare(reflected, best) < 0) {
  207.             // Compute the expansion point.
  208.             final double[] xE = new double[n];
  209.             for (int j = 0; j < n; j++) {
  210.                 xE[j] = centroid[j] + khi * (xR[j] - centroid[j]);
  211.             }
  212.             final PointValuePair expanded
  213.                 = new PointValuePair(xE, evaluationFunction.value(xE), false);

  214.             if (comparator.compare(expanded, reflected) < 0) {
  215.                 // Accept the expansion point.
  216.                 replaceWorstPoint(expanded, comparator);
  217.             } else {
  218.                 // Accept the reflected point.
  219.                 replaceWorstPoint(reflected, comparator);
  220.             }
  221.         } else {
  222.             if (comparator.compare(reflected, worst) < 0) {
  223.                 // Perform an outside contraction.
  224.                 final double[] xC = new double[n];
  225.                 for (int j = 0; j < n; j++) {
  226.                     xC[j] = centroid[j] + gamma * (xR[j] - centroid[j]);
  227.                 }
  228.                 final PointValuePair outContracted
  229.                     = new PointValuePair(xC, evaluationFunction.value(xC), false);
  230.                 if (comparator.compare(outContracted, reflected) <= 0) {
  231.                     // Accept the contraction point.
  232.                     replaceWorstPoint(outContracted, comparator);
  233.                     return;
  234.                 }
  235.             } else {
  236.                 // Perform an inside contraction.
  237.                 final double[] xC = new double[n];
  238.                 for (int j = 0; j < n; j++) {
  239.                     xC[j] = centroid[j] - gamma * (centroid[j] - xWorst[j]);
  240.                 }
  241.                 final PointValuePair inContracted
  242.                     = new PointValuePair(xC, evaluationFunction.value(xC), false);

  243.                 if (comparator.compare(inContracted, worst) < 0) {
  244.                     // Accept the contraction point.
  245.                     replaceWorstPoint(inContracted, comparator);
  246.                     return;
  247.                 }
  248.             }

  249.             // Perform a shrink.
  250.             final double[] xSmallest = getPoint(0).getPointRef();
  251.             for (int i = 1; i <= n; i++) {
  252.                 final double[] x = getPoint(i).getPoint();
  253.                 for (int j = 0; j < n; j++) {
  254.                     x[j] = xSmallest[j] + sigma * (x[j] - xSmallest[j]);
  255.                 }
  256.                 setPoint(i, new PointValuePair(x, Double.NaN, false));
  257.             }
  258.             evaluate(evaluationFunction, comparator);
  259.         }
  260.     }
  261. }