ADMMQPOptimizer.java

  1. /*
  2.  * Licensed to the Hipparchus project under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The Hipparchus project licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      https://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.hipparchus.optim.nonlinear.vector.constrained;

  18. import java.util.ArrayList;
  19. import java.util.List;

  20. import org.hipparchus.exception.LocalizedCoreFormats;
  21. import org.hipparchus.exception.MathIllegalArgumentException;
  22. import org.hipparchus.linear.Array2DRowRealMatrix;
  23. import org.hipparchus.linear.ArrayRealVector;
  24. import org.hipparchus.linear.RealMatrix;
  25. import org.hipparchus.linear.RealVector;
  26. import org.hipparchus.optim.ConvergenceChecker;
  27. import org.hipparchus.optim.LocalizedOptimFormats;
  28. import org.hipparchus.optim.OptimizationData;
  29. import org.hipparchus.optim.nonlinear.scalar.ObjectiveFunction;
  30. import org.hipparchus.util.FastMath;
  31. import org.hipparchus.util.MathUtils;

  32. /**
  33.  * Alternating Direction Method of Multipliers Quadratic Programming Optimizer.
  34.  * \[
  35.  *  min \frac{1}{2} X^T Q X + G X a\\
  36.  *  A  X    = B_1\\
  37.  *  B  X    \ge B_2\\
  38.  *  l_b \le C X \le u_b
  39.  * \]
  40.  * Algorithm based on paper:"An Operator Splitting Solver for Quadratic Programs(Bartolomeo Stellato, Goran Banjac, Paul Goulart, Alberto Bemporad, Stephen Boyd,February 13 2020)"
  41.  * @since 3.1
  42.  */

  43. public class ADMMQPOptimizer extends QPOptimizer {

  44.     /** Algorithm settings. */
  45.     private ADMMQPOption settings;

  46.     /** Equality constraint (may be null). */
  47.     private LinearEqualityConstraint eqConstraint;

  48.     /** Inequality constraint (may be null). */
  49.     private LinearInequalityConstraint iqConstraint;

  50.     /** Boundary constraint (may be null). */
  51.     private LinearBoundedConstraint bqConstraint;

  52.     /** Objective function. */
  53.     private QuadraticFunction function;

  54.     /** Problem solver. */
  55.     private final ADMMQPKKT solver;

  56.     /** Problem convergence checker. */
  57.     private ADMMQPConvergenceChecker checker;

  58.     /** Convergence indicator. */
  59.     private boolean converged;

  60.     /** Current step size. */
  61.     private double rho;

  62.     /** Simple constructor.
  63.      * <p>
  64.      * This constructor sets all {@link ADMMQPOption options} to their default values
  65.      * </p>
  66.      */
  67.     public ADMMQPOptimizer() {
  68.         settings   = new ADMMQPOption();
  69.         solver     = new ADMMQPKKT();
  70.         converged  = false;
  71.         rho        = 0.1;
  72.     }

  73.     /** {@inheritDoc} */
  74.     @Override
  75.     public ConvergenceChecker<LagrangeSolution> getConvergenceChecker() {
  76.         return checker;
  77.     }

  78.     /** {@inheritDoc} */
  79.     @Override
  80.     public LagrangeSolution optimize(OptimizationData... optData) {
  81.         return super.optimize(optData);
  82.     }

  83.     /** {@inheritDoc} */
  84.     @Override
  85.     protected void parseOptimizationData(OptimizationData... optData) {
  86.         super.parseOptimizationData(optData);
  87.         for (OptimizationData data: optData) {

  88.              if (data instanceof ObjectiveFunction) {
  89.                 function = (QuadraticFunction) ((ObjectiveFunction) data).getObjectiveFunction();
  90.                 continue;
  91.             }

  92.             if (data instanceof LinearEqualityConstraint) {
  93.                 eqConstraint = (LinearEqualityConstraint) data;
  94.                 continue;
  95.             }
  96.             if (data instanceof LinearInequalityConstraint) {
  97.                 iqConstraint = (LinearInequalityConstraint) data;
  98.                 continue;
  99.             }

  100.             if (data instanceof LinearBoundedConstraint) {
  101.                 bqConstraint = (LinearBoundedConstraint) data;
  102.                 continue;
  103.             }

  104.             if (data instanceof ADMMQPOption) {
  105.                 settings = (ADMMQPOption) data;
  106.             }

  107.         }
  108.         // if we got here, convexObjective exists
  109.         int n = function.dim();
  110.         if (eqConstraint != null) {
  111.             int nDual = eqConstraint.dimY();
  112.             if (nDual >= n) {
  113.                 throw new MathIllegalArgumentException(LocalizedOptimFormats.CONSTRAINTS_RANK, nDual, n);
  114.             }
  115.             int nTest = eqConstraint.getA().getColumnDimension();
  116.             if (nDual == 0) {
  117.                 throw new MathIllegalArgumentException(LocalizedCoreFormats.ZERO_NOT_ALLOWED);
  118.             }
  119.             MathUtils.checkDimension(nTest, n);
  120.         }

  121.     }

  122.     /** {@inheritDoc} */
  123.     @Override
  124.     public LagrangeSolution doOptimize() {
  125.         final int n = function.dim();
  126.         int me = 0;
  127.         int mi = 0;
  128.         int mb = 0;
  129.         int rhoUpdateCount = 0;

  130.         //PHASE 1 First Solution


  131.        //QUADRATIC TERM
  132.         RealMatrix H = function.getP();
  133.        //GRADIENT
  134.         RealVector q = function.getQ();


  135.        //EQUALITY CONSTRAINT
  136.         if (eqConstraint != null) {
  137.             me = eqConstraint.dimY();
  138.         }
  139.        //INEQUALITY CONSTRAINT
  140.         if (iqConstraint != null) {
  141.             mi = iqConstraint.dimY();
  142.         }
  143.         //BOUNDED CONSTRAINT
  144.         if (bqConstraint != null) {
  145.             mb = bqConstraint.dimY();
  146.         }

  147.         RealVector lb = new ArrayRealVector(me + mi + mb);
  148.         RealVector ub = new ArrayRealVector(me + mi + mb);

  149.         //COMPOSE A MATRIX AND LOWER AND UPPER BOUND
  150.         RealMatrix A = new Array2DRowRealMatrix(me + mi + mb, n);
  151.         if (eqConstraint != null) {
  152.             A.setSubMatrix(eqConstraint.jacobian(null).getData(), 0, 0);
  153.             lb.setSubVector(0,eqConstraint.getLowerBound());
  154.             ub.setSubVector(0,eqConstraint.getUpperBound());
  155.         }
  156.         if (iqConstraint != null) {
  157.             A.setSubMatrix(iqConstraint.jacobian(null).getData(), me, 0);
  158.             ub.setSubVector(me,iqConstraint.getUpperBound());
  159.             lb.setSubVector(me,iqConstraint.getLowerBound());
  160.         }

  161.         if (mb > 0) {
  162.             A.setSubMatrix(bqConstraint.jacobian(null).getData(), me + mi, 0);
  163.             ub.setSubVector(me + mi,bqConstraint.getUpperBound());
  164.             lb.setSubVector(me + mi,bqConstraint.getLowerBound());
  165.         }

  166.         checker = new ADMMQPConvergenceChecker(H, A, q, settings.getEps(), settings.getEps());

  167.         //SETUP WORKING MATRIX
  168.         RealMatrix Hw = H.copy();
  169.         RealMatrix Aw = A.copy();
  170.         RealVector qw = q.copy();
  171.         RealVector ubw = ub.copy();
  172.         RealVector lbw = lb.copy();
  173.         RealVector x;
  174.         if (getStartPoint() != null) {
  175.             x = new ArrayRealVector(getStartPoint());
  176.         } else {
  177.             x = new ArrayRealVector(function.dim());
  178.         }

  179.         ADMMQPModifiedRuizEquilibrium dec = new ADMMQPModifiedRuizEquilibrium(H, A,q);

  180.         if (settings.isScaling()) {
  181.            //
  182.             dec.normalize(settings.getEps(), settings.getScaleMaxIteration());
  183.             Hw = dec.getScaledH();
  184.             Aw = dec.getScaledA();
  185.             qw = dec.getScaledQ();
  186.             lbw = dec.getScaledLUb(lb);
  187.             ubw = dec.getScaledLUb(ub);

  188.             x = dec.scaleX(x.copy());

  189.         }

  190.         final ADMMQPConvergenceChecker checkerRho = new ADMMQPConvergenceChecker(Hw, Aw, qw, settings.getEps(), settings.getEps());
  191.         //SETUP VECTOR SOLUTION

  192.         RealVector z = Aw.operate(x);
  193.         RealVector y = new ArrayRealVector(me + mi + mb);

  194.         solver.initialize(Hw, Aw, qw, me, lbw, ubw, rho, settings.getSigma(), settings.getAlpha());
  195.         RealVector xstar = null;
  196.         RealVector ystar = null;
  197.         RealVector zstar;

  198.         while (iterations.getCount() <= iterations.getMaximalCount()) {
  199.             ADMMQPSolution sol = solver.iterate(x, y, z);
  200.             x = sol.getX();
  201.             y = sol.getLambda();
  202.             z = sol.getZ();
  203.             //new ArrayRealVector(me + mi + mb);
  204.             if (rhoUpdateCount < settings.getMaxRhoIteration()) {
  205.                 double rp       = checkerRho.residualPrime(x, z);
  206.                 double rd       = checkerRho.residualDual(x, y);
  207.                 double maxP     = checkerRho.maxPrimal(x, z);
  208.                 double maxD     = checkerRho.maxDual(x, y);
  209.                 boolean updated = manageRho(me, rp, rd, maxP, maxD);

  210.                 if (updated) {
  211.                     ++rhoUpdateCount;
  212.                 }
  213.             }


  214.             if (settings.isScaling()) {

  215.                 xstar = dec.unscaleX(x);
  216.                 ystar = dec.unscaleY(y);
  217.                 zstar = dec.unscaleZ(z);

  218.             } else {

  219.                 xstar = x.copy();
  220.                 ystar = y.copy();
  221.                 zstar = z.copy();

  222.             }

  223.             double rp        = checker.residualPrime(xstar, zstar);
  224.             double rd        = checker.residualDual(xstar, ystar);
  225.             double maxPrimal = checker.maxPrimal(xstar, zstar);
  226.             double maxDual   = checker.maxDual(xstar, ystar);

  227.             if (checker.converged(rp, rd, maxPrimal, maxDual)) {
  228.                 converged = true;
  229.                 break;
  230.             }
  231.             iterations.increment();

  232.         }

  233.         //SOLUTION POLISHING
  234.         if (settings.isPolishing()) {
  235.             ADMMQPSolution finalSol = polish(Hw, Aw, qw, lbw, ubw, x, y, z);
  236.             if (settings.isScaling()) {
  237.                 xstar = dec.unscaleX(finalSol.getX());
  238.                 ystar = dec.unscaleY(finalSol.getLambda());
  239.             } else {
  240.                 xstar = finalSol.getX();
  241.                 ystar = finalSol.getLambda();
  242.             }
  243.         }
  244.         for (int i = 0; i < me + mi; i++) {
  245.             ystar.setEntry(i, -ystar.getEntry(i));
  246.         }

  247.         return new LagrangeSolution(xstar, ystar, function.value(xstar));

  248.     }

  249.     /** Check if convergence has been reached.
  250.      * @return true if convergence has been reached
  251.      */
  252.     public boolean isConverged() {
  253.         return converged;
  254.     }

  255.     /** Polish solution.
  256.      * @param H quadratic term matrix
  257.      * @param A constraint coefficients matrix
  258.      * @param q linear term matrix
  259.      * @param lb lower bound
  260.      * @param ub upper bound
  261.      * @param x primal problem solution
  262.      * @param y dual problem solution
  263.      * @param z auxiliary variable
  264.      * @return polished solution
  265.      */
  266.     private ADMMQPSolution polish(RealMatrix H, RealMatrix A, RealVector q, RealVector lb, RealVector ub,
  267.                                   RealVector x, RealVector y, RealVector z) {

  268.         List<double[]> Aentry    = new ArrayList<>();
  269.         List<Double>  lubEntry   = new ArrayList<>();
  270.         List<Double>  yEntry     = new ArrayList<>();

  271.         // FIND ACTIVE ON LOWER BAND
  272.         for (int j = 0; j < A.getRowDimension(); j++) {
  273.             if (z.getEntry(j) - lb.getEntry(j) < -y.getEntry(j)) {  // lower-active

  274.                 Aentry.add(A.getRow(j));
  275.                 lubEntry.add(lb.getEntry(j));
  276.                 yEntry.add(y.getEntry(j));

  277.             }
  278.         }
  279.         //FIND ACTIVE ON UPPER BAND
  280.         for (int j = 0; j < A.getRowDimension(); j++) {
  281.             if (-z.getEntry(j) + ub.getEntry(j) < y.getEntry(j)) { // lower-active

  282.                 Aentry.add(A.getRow(j));
  283.                 lubEntry.add(ub.getEntry(j));
  284.                 yEntry.add(y.getEntry(j));

  285.             }

  286.         }
  287.         RealMatrix Aactive;
  288.         RealVector lub;

  289.         RealVector ystar;
  290.         RealVector xstar = x.copy();
  291.         //!Aentry.isEmpty()
  292.         if (!Aentry.isEmpty()) {

  293.             Aactive = new Array2DRowRealMatrix(Aentry.toArray(new double[0][]));
  294.             lub = new ArrayRealVector(lubEntry.toArray(new Double[0]));
  295.             ystar = new ArrayRealVector(yEntry.toArray(new Double[0]));
  296.             solver.initialize(H, Aactive, q, 0, lub, lub,
  297.                               settings.getSigma(), settings.getSigma(), settings.getAlpha());

  298.             for (int i = 0; i < settings.getPolishIteration(); i++) {
  299.                 RealVector kttx = (H.operate(xstar)).add(Aactive.transpose().operate(ystar));
  300.                 RealVector ktty = Aactive.operate(xstar);
  301.                 RealVector b1 = q.mapMultiply(-1.0).subtract(kttx);
  302.                 RealVector b2 = lub.mapMultiply(1.0).subtract(ktty);
  303.                 ADMMQPSolution dxz = solver.solve(b1,b2);
  304.                 xstar = xstar.add(dxz.getX());
  305.                 ystar = ystar.add(dxz.getV());
  306.             }

  307.             return new ADMMQPSolution(xstar, null, y, A.operate(xstar));

  308.         } else {
  309.             return new ADMMQPSolution(x, null, y, z);
  310.         }
  311.     }

  312.     /** Manage step size.
  313.      * @param me number of equality constraints
  314.      * @param rp primal residual
  315.      * @param rd dual residual
  316.      * @param maxPrimal primal vectors max
  317.      * @param maxDual dual vectors max
  318.      * @return true if rho has been updated
  319.      */
  320.     private boolean manageRho(int me, double rp, double rd, double maxPrimal, double maxDual) {
  321.         boolean updated = false;
  322.         if (settings.updateRho()) {

  323.             // estimate new step size
  324.             double rhonew = FastMath.min(FastMath.max(rho * FastMath.sqrt((rp * maxDual) / (rd * maxPrimal)),
  325.                                                       settings.getRhoMin()),
  326.                                          settings.getRhoMax());

  327.             if ((rhonew > rho * 5.0) || (rhonew < rho / 5.0)) {

  328.                 rho = rhonew;
  329.                 updated = true;

  330.                 solver.updateSigmaRho(settings.getSigma(), me, rho);
  331.             }
  332.         }
  333.         return updated;
  334.     }

  335. }