ADMMQPKKT.java

  1. /*
  2.  * Licensed to the Hipparchus project under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The Hipparchus project licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      https://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.hipparchus.optim.nonlinear.vector.constrained;


  18. import org.hipparchus.linear.ArrayRealVector;
  19. import org.hipparchus.linear.DecompositionSolver;
  20. import org.hipparchus.linear.EigenDecompositionSymmetric;
  21. import org.hipparchus.linear.MatrixUtils;
  22. import org.hipparchus.linear.RealMatrix;
  23. import org.hipparchus.linear.RealVector;
  24. import org.hipparchus.util.FastMath;

  25. /** Alternative Direction Method of Multipliers Solver.
  26.  * @since 3.1
  27.  */
  28. public class ADMMQPKKT implements KarushKuhnTuckerSolver<ADMMQPSolution> {

  29.     /** Square matrix of weights for quadratic terms. */
  30.     private RealMatrix H;

  31.     /** Vector of weights for linear terms. */
  32.     private RealVector q;

  33.     /** Constraints coefficients matrix. */
  34.     private RealMatrix A;

  35.     /** Regularization term sigma for Karush–Kuhn–Tucker solver. */
  36.     private double sigma;

  37.     /** TBC. */
  38.     private RealMatrix R;

  39.     /** Inverse of R. */
  40.     private RealMatrix Rinv;

  41.     /** Lower bound. */
  42.     private RealVector lb;

  43.     /** Upper bound. */
  44.     private RealVector ub;

  45.     /** Alpha filter for ADMM iteration. */
  46.     private double alpha;

  47.     /** Constrained problem KKT matrix. */
  48.     private RealMatrix M; // NOPMD

  49.     /** Solver for M. */
  50.     private DecompositionSolver dsX;

  51.     /** Simple constructor.
  52.      * <p>
  53.      * BEWARE, nothing is initialized here, it is {@link #initialize(RealMatrix, RealMatrix,
  54.      * RealVector, int, RealVector, RealVector, double, double, double) initialize} <em>must</em>
  55.      * be called before using the instance.
  56.      * </p>
  57.      */
  58.     ADMMQPKKT() {
  59.         // nothing initialized yet!
  60.     }

  61.     /** {@inheritDoc} */
  62.     @Override
  63.     public ADMMQPSolution solve(RealVector b1, final RealVector b2) {
  64.         RealVector z = dsX.solve(new ArrayRealVector((ArrayRealVector) b1,b2));
  65.         return new ADMMQPSolution(z.getSubVector(0,b1.getDimension()), z.getSubVector(b1.getDimension(), b2.getDimension()));
  66.     }

  67.     /** Update steps
  68.      * @param newSigma new regularization term sigma for Karush–Kuhn–Tucker solver
  69.      * @param me number of equality constraints
  70.      * @param rho new step size
  71.      */
  72.     public void updateSigmaRho(double newSigma, int me, double rho) {
  73.         this.sigma = newSigma;
  74.         this.H = H.add(MatrixUtils.createRealIdentityMatrix(H.getColumnDimension()).scalarMultiply(newSigma));
  75.         createPenaltyMatrix(me, rho);
  76.         M =  MatrixUtils.createRealMatrix(H.getRowDimension() + A.getRowDimension(),
  77.                                           H.getRowDimension() + A.getRowDimension());
  78.         M.setSubMatrix(H.getData(), 0,0);
  79.         M.setSubMatrix(A.getData(), H.getRowDimension(),0);
  80.         M.setSubMatrix(A.transpose().getData(), 0, H.getRowDimension());
  81.         M.setSubMatrix(Rinv.scalarMultiply(-1.0).getData(), H.getRowDimension(),H.getRowDimension());
  82.         dsX = new EigenDecompositionSymmetric(M).getSolver();
  83.     }

  84.     /** Initialize problem
  85.      * @param newH square matrix of weights for quadratic term
  86.      * @param newA constraints coefficients matrix
  87.      * @param newQ TBD
  88.      * @param me number of equality constraints
  89.      * @param newLb lower bound
  90.      * @param newUb upper bound
  91.      * @param rho step size
  92.      * @param newSigma regularization term sigma for Karush–Kuhn–Tucker solver
  93.      * @param newAlpha alpha filter for ADMM iteration
  94.      */
  95.     public void initialize(RealMatrix newH, RealMatrix newA, RealVector newQ,
  96.                            int me, RealVector newLb, RealVector newUb,
  97.                            double rho, double newSigma, double newAlpha) {
  98.         this.lb = newLb;
  99.         this.ub = newUb;
  100.         this.alpha = newAlpha;
  101.         this.sigma = newSigma;
  102.         this.H = newH.add(MatrixUtils.createRealIdentityMatrix(newH.getColumnDimension()).scalarMultiply(newSigma));
  103.         this.A = newA.copy();
  104.         this.q = newQ.copy();
  105.         createPenaltyMatrix(me, rho);

  106.         M =  MatrixUtils.createRealMatrix(newH.getRowDimension() + newA.getRowDimension(),
  107.                                           newH.getRowDimension() + newA.getRowDimension());
  108.         M.setSubMatrix(newH.getData(),0,0);
  109.         M.setSubMatrix(newA.getData(),newH.getRowDimension(),0);
  110.         M.setSubMatrix(newA.transpose().getData(),0,newH.getRowDimension());
  111.         M.setSubMatrix(Rinv.scalarMultiply(-1.0).getData(),newH.getRowDimension(),newH.getRowDimension());
  112.         dsX = new EigenDecompositionSymmetric(M).getSolver();
  113.     }

  114.     private void createPenaltyMatrix(int me, double rho) {
  115.         this.R = MatrixUtils.createRealIdentityMatrix(A.getRowDimension());

  116.         for (int i = 0; i < R.getRowDimension(); i++) {
  117.             if (i < me) {
  118.                 R.setEntry(i, i, rho * 1000.0);

  119.             } else {
  120.                 R.setEntry(i, i, rho);

  121.             }
  122.         }
  123.         this.Rinv = MatrixUtils.inverse(R);
  124.     }

  125.     /** {@inheritDoc} */
  126.     @Override
  127.     public ADMMQPSolution iterate(RealVector... previousSol) {
  128.         double onealfa = 1.0 - alpha;
  129.         //SAVE OLD VALUE
  130.         RealVector xold = previousSol[0].copy();
  131.         RealVector yold = previousSol[1].copy();
  132.         RealVector zold = previousSol[2].copy();

  133.         //UPDATE RIGHT VECTOR
  134.         RealVector b1 = previousSol[0].mapMultiply(sigma).subtract(q);
  135.         RealVector b2 = previousSol[2].subtract(Rinv.operate(previousSol[1]));

  136.         //SOLVE KKT SYSYEM
  137.         ADMMQPSolution sol = solve(b1, b2);
  138.         RealVector xtilde = sol.getX();
  139.         RealVector vtilde = sol.getV();

  140.         //UPDATE ZTILDE
  141.         RealVector ztilde = zold.add(Rinv.operate(vtilde.subtract(yold)));
  142.         //UPDATE X
  143.         previousSol[0] = xtilde.mapMultiply(alpha).add(xold.mapMultiply(onealfa));

  144.         //UPDATE Z PARTIAL
  145.         RealVector zpartial = ztilde.mapMultiply(alpha).add(zold.mapMultiply(onealfa)).add(Rinv.operate(yold));

  146.         //PROJECT ZPARTIAL AND UPDATE Z
  147.         for (int j = 0; j < previousSol[2].getDimension(); j++) {
  148.             previousSol[2].setEntry(j, FastMath.min(FastMath.max(zpartial.getEntry(j), lb.getEntry(j)), ub.getEntry(j)));
  149.         }

  150.         //UPDATE Y
  151.         RealVector ytilde = ztilde.mapMultiply(alpha).add(zold.mapMultiply(onealfa).subtract(previousSol[2]));
  152.         previousSol[1] = yold.add(R.operate(ytilde));

  153.         return new ADMMQPSolution(previousSol[0], vtilde, previousSol[1], previousSol[2]);
  154.     }

  155. }