View Javadoc
1   /*
2    * Licensed to the Hipparchus project under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The Hipparchus project licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.hipparchus.optim.nonlinear.vector.constrained;
18  
19  import java.util.ArrayList;
20  import java.util.List;
21  
22  import org.hipparchus.exception.LocalizedCoreFormats;
23  import org.hipparchus.exception.MathIllegalArgumentException;
24  import org.hipparchus.linear.Array2DRowRealMatrix;
25  import org.hipparchus.linear.ArrayRealVector;
26  import org.hipparchus.linear.RealMatrix;
27  import org.hipparchus.linear.RealVector;
28  import org.hipparchus.optim.ConvergenceChecker;
29  import org.hipparchus.optim.LocalizedOptimFormats;
30  import org.hipparchus.optim.OptimizationData;
31  import org.hipparchus.optim.nonlinear.scalar.ObjectiveFunction;
32  import org.hipparchus.util.FastMath;
33  import org.hipparchus.util.MathUtils;
34  
35  /**
36   * Alternating Direction Method of Multipliers Quadratic Programming Optimizer.
37   * \[
38   *  min \frac{1}{2} X^T Q X + G X a\\
39   *  A  X    = B_1\\
40   *  B  X    \ge B_2\\
41   *  l_b \le C X \le u_b
42   * \]
43   * Algorithm based on paper:"An Operator Splitting Solver for Quadratic Programs(Bartolomeo Stellato, Goran Banjac, Paul Goulart, Alberto Bemporad, Stephen Boyd,February 13 2020)"
44   * @since 3.1
45   */
46  
47  public class ADMMQPOptimizer extends QPOptimizer {
48  
49      /** Algorithm settings. */
50      private ADMMQPOption settings;
51  
52      /** Equality constraint (may be null). */
53      private LinearEqualityConstraint eqConstraint;
54  
55      /** Inequality constraint (may be null). */
56      private LinearInequalityConstraint iqConstraint;
57  
58      /** Boundary constraint (may be null). */
59      private LinearBoundedConstraint bqConstraint;
60  
61      /** Objective function. */
62      private QuadraticFunction function;
63  
64      /** Problem solver. */
65      private final ADMMQPKKT solver;
66  
67      /** Problem convergence checker. */
68      private ADMMQPConvergenceChecker checker;
69  
70      /** Convergence indicator. */
71      private boolean converged;
72  
73      /** Current step size. */
74      private double rho;
75  
76      /** Simple constructor.
77       * <p>
78       * This constructor sets all {@link ADMMQPOption options} to their default values
79       * </p>
80       */
81      public ADMMQPOptimizer() {
82          settings   = new ADMMQPOption();
83          solver     = new ADMMQPKKT();
84          converged  = false;
85          rho        = 0.1;
86      }
87  
88      /** {@inheritDoc} */
89      @Override
90      public ConvergenceChecker<LagrangeSolution> getConvergenceChecker() {
91          return checker;
92      }
93  
94      /** {@inheritDoc} */
95      @Override
96      public LagrangeSolution optimize(OptimizationData... optData) {
97          return super.optimize(optData);
98      }
99  
100     /** {@inheritDoc} */
101     @Override
102     protected void parseOptimizationData(OptimizationData... optData) {
103         super.parseOptimizationData(optData);
104         for (OptimizationData data: optData) {
105 
106              if (data instanceof ObjectiveFunction) {
107                 function = (QuadraticFunction) ((ObjectiveFunction) data).getObjectiveFunction();
108                 continue;
109             }
110 
111             if (data instanceof LinearEqualityConstraint) {
112                 eqConstraint = (LinearEqualityConstraint) data;
113                 continue;
114             }
115             if (data instanceof LinearInequalityConstraint) {
116                 iqConstraint = (LinearInequalityConstraint) data;
117                 continue;
118             }
119 
120             if (data instanceof LinearBoundedConstraint) {
121                 bqConstraint = (LinearBoundedConstraint) data;
122                 continue;
123             }
124 
125             if (data instanceof ADMMQPOption) {
126                 settings = (ADMMQPOption) data;
127             }
128 
129         }
130         // if we got here, convexObjective exists
131         int n = function.dim();
132         if (eqConstraint != null) {
133             int nDual = eqConstraint.dimY();
134             if (nDual >= n) {
135                 throw new MathIllegalArgumentException(LocalizedOptimFormats.CONSTRAINTS_RANK, nDual, n);
136             }
137             int nTest = eqConstraint.getA().getColumnDimension();
138             if (nDual == 0) {
139                 throw new MathIllegalArgumentException(LocalizedCoreFormats.ZERO_NOT_ALLOWED);
140             }
141             MathUtils.checkDimension(nTest, n);
142         }
143 
144     }
145 
146     /** {@inheritDoc} */
147     @Override
148     public LagrangeSolution doOptimize() {
149         final int n = function.dim();
150         int me = 0;
151         int mi = 0;
152         int mb = 0;
153         int rhoUpdateCount = 0;
154 
155         //PHASE 1 First Solution
156 
157 
158        //QUADRATIC TERM
159         RealMatrix H = function.getP();
160        //GRADIENT
161         RealVector q = function.getQ();
162 
163 
164        //EQUALITY CONSTRAINT
165         if (eqConstraint != null) {
166             me = eqConstraint.dimY();
167         }
168        //INEQUALITY CONSTRAINT
169         if (iqConstraint != null) {
170             mi = iqConstraint.dimY();
171         }
172         //BOUNDED CONSTRAINT
173         if (bqConstraint != null) {
174             mb = bqConstraint.dimY();
175         }
176 
177         RealVector lb = new ArrayRealVector(me + mi + mb);
178         RealVector ub = new ArrayRealVector(me + mi + mb);
179 
180         //COMPOSE A MATRIX AND LOWER AND UPPER BOUND
181         RealMatrix A = new Array2DRowRealMatrix(me + mi + mb, n);
182         if (eqConstraint != null) {
183             A.setSubMatrix(eqConstraint.jacobian(null).getData(), 0, 0);
184             lb.setSubVector(0,eqConstraint.getLowerBound());
185             ub.setSubVector(0,eqConstraint.getUpperBound());
186         }
187         if (iqConstraint != null) {
188             A.setSubMatrix(iqConstraint.jacobian(null).getData(), me, 0);
189             ub.setSubVector(me,iqConstraint.getUpperBound());
190             lb.setSubVector(me,iqConstraint.getLowerBound());
191         }
192 
193         if (mb > 0) {
194             A.setSubMatrix(bqConstraint.jacobian(null).getData(), me + mi, 0);
195             ub.setSubVector(me + mi,bqConstraint.getUpperBound());
196             lb.setSubVector(me + mi,bqConstraint.getLowerBound());
197         }
198 
199         checker = new ADMMQPConvergenceChecker(H, A, q, settings.getEps(), settings.getEps());
200 
201         //SETUP WORKING MATRIX
202         RealMatrix Hw = H.copy();
203         RealMatrix Aw = A.copy();
204         RealVector qw = q.copy();
205         RealVector ubw = ub.copy();
206         RealVector lbw = lb.copy();
207         RealVector x;
208         if (getStartPoint() != null) {
209             x = new ArrayRealVector(getStartPoint());
210         } else {
211             x = new ArrayRealVector(function.dim());
212         }
213 
214         ADMMQPModifiedRuizEquilibrium dec = new ADMMQPModifiedRuizEquilibrium(H, A,q);
215 
216         if (settings.isScaling()) {
217            //
218             dec.normalize(settings.getEps(), settings.getScaleMaxIteration());
219             Hw = dec.getScaledH();
220             Aw = dec.getScaledA();
221             qw = dec.getScaledQ();
222             lbw = dec.getScaledLUb(lb);
223             ubw = dec.getScaledLUb(ub);
224 
225             x = dec.scaleX(x.copy());
226 
227         }
228 
229         final ADMMQPConvergenceChecker checkerRho = new ADMMQPConvergenceChecker(Hw, Aw, qw, settings.getEps(), settings.getEps());
230         //SETUP VECTOR SOLUTION
231 
232         RealVector z = Aw.operate(x);
233         RealVector y = new ArrayRealVector(me + mi + mb);
234 
235         solver.initialize(Hw, Aw, qw, me, lbw, ubw, rho, settings.getSigma(), settings.getAlpha());
236         RealVector xstar = null;
237         RealVector ystar = null;
238         RealVector zstar;
239 
240         while (iterations.getCount() <= iterations.getMaximalCount()) {
241             ADMMQPSolution sol = solver.iterate(x, y, z);
242             x = sol.getX();
243             y = sol.getLambda();
244             z = sol.getZ();
245             //new ArrayRealVector(me + mi + mb);
246             if (rhoUpdateCount < settings.getMaxRhoIteration()) {
247                 double rp       = checkerRho.residualPrime(x, z);
248                 double rd       = checkerRho.residualDual(x, y);
249                 double maxP     = checkerRho.maxPrimal(x, z);
250                 double maxD     = checkerRho.maxDual(x, y);
251                 boolean updated = manageRho(me, rp, rd, maxP, maxD);
252 
253                 if (updated) {
254                     ++rhoUpdateCount;
255                 }
256             }
257 
258 
259             if (settings.isScaling()) {
260 
261                 xstar = dec.unscaleX(x);
262                 ystar = dec.unscaleY(y);
263                 zstar = dec.unscaleZ(z);
264 
265             } else {
266 
267                 xstar = x.copy();
268                 ystar = y.copy();
269                 zstar = z.copy();
270 
271             }
272 
273             double rp        = checker.residualPrime(xstar, zstar);
274             double rd        = checker.residualDual(xstar, ystar);
275             double maxPrimal = checker.maxPrimal(xstar, zstar);
276             double maxDual   = checker.maxDual(xstar, ystar);
277 
278             if (checker.converged(rp, rd, maxPrimal, maxDual)) {
279                 converged = true;
280                 break;
281             }
282             iterations.increment();
283 
284         }
285 
286         //SOLUTION POLISHING
287         if (settings.isPolishing()) {
288             ADMMQPSolution finalSol = polish(Hw, Aw, qw, lbw, ubw, x, y, z);
289             if (settings.isScaling()) {
290                 xstar = dec.unscaleX(finalSol.getX());
291                 ystar = dec.unscaleY(finalSol.getLambda());
292             } else {
293                 xstar = finalSol.getX();
294                 ystar = finalSol.getLambda();
295             }
296         }
297         for (int i = 0; i < me + mi; i++) {
298             ystar.setEntry(i, -ystar.getEntry(i));
299         }
300 
301         return new LagrangeSolution(xstar, ystar, function.value(xstar));
302 
303     }
304 
305     /** Check if convergence has been reached.
306      * @return true if convergence has been reached
307      */
308     public boolean isConverged() {
309         return converged;
310     }
311 
312     /** Polish solution.
313      * @param H quadratic term matrix
314      * @param A constraint coefficients matrix
315      * @param q linear term matrix
316      * @param lb lower bound
317      * @param ub upper bound
318      * @param x primal problem solution
319      * @param y dual problem solution
320      * @param z auxiliary variable
321      * @return polished solution
322      */
323     private ADMMQPSolution polish(RealMatrix H, RealMatrix A, RealVector q, RealVector lb, RealVector ub,
324                                   RealVector x, RealVector y, RealVector z) {
325 
326         List<double[]> Aentry    = new ArrayList<>();
327         List<Double>  lubEntry   = new ArrayList<>();
328         List<Double>  yEntry     = new ArrayList<>();
329 
330         // FIND ACTIVE ON LOWER BAND
331         for (int j = 0; j < A.getRowDimension(); j++) {
332             if (z.getEntry(j) - lb.getEntry(j) < -y.getEntry(j)) {  // lower-active
333 
334                 Aentry.add(A.getRow(j));
335                 lubEntry.add(lb.getEntry(j));
336                 yEntry.add(y.getEntry(j));
337 
338             }
339         }
340         //FIND ACTIVE ON UPPER BAND
341         for (int j = 0; j < A.getRowDimension(); j++) {
342             if (-z.getEntry(j) + ub.getEntry(j) < y.getEntry(j)) { // lower-active
343 
344                 Aentry.add(A.getRow(j));
345                 lubEntry.add(ub.getEntry(j));
346                 yEntry.add(y.getEntry(j));
347 
348             }
349 
350         }
351         RealMatrix Aactive;
352         RealVector lub;
353 
354         RealVector ystar;
355         RealVector xstar = x.copy();
356         //!Aentry.isEmpty()
357         if (!Aentry.isEmpty()) {
358 
359             Aactive = new Array2DRowRealMatrix(Aentry.toArray(new double[0][]));
360             lub = new ArrayRealVector(lubEntry.toArray(new Double[0]));
361             ystar = new ArrayRealVector(yEntry.toArray(new Double[0]));
362             solver.initialize(H, Aactive, q, 0, lub, lub,
363                               settings.getSigma(), settings.getSigma(), settings.getAlpha());
364 
365             for (int i = 0; i < settings.getPolishIteration(); i++) {
366                 RealVector kttx = (H.operate(xstar)).add(Aactive.transpose().operate(ystar));
367                 RealVector ktty = Aactive.operate(xstar);
368                 RealVector b1 = q.mapMultiply(-1.0).subtract(kttx);
369                 RealVector b2 = lub.mapMultiply(1.0).subtract(ktty);
370                 ADMMQPSolution dxz = solver.solve(b1,b2);
371                 xstar = xstar.add(dxz.getX());
372                 ystar = ystar.add(dxz.getV());
373             }
374 
375             return new ADMMQPSolution(xstar, null, y, A.operate(xstar));
376 
377         } else {
378             return new ADMMQPSolution(x, null, y, z);
379         }
380     }
381 
382     /** Manage step size.
383      * @param me number of equality constraints
384      * @param rp primal residual
385      * @param rd dual residual
386      * @param maxPrimal primal vectors max
387      * @param maxDual dual vectors max
388      * @return true if rho has been updated
389      */
390     private boolean manageRho(int me, double rp, double rd, double maxPrimal, double maxDual) {
391         boolean updated = false;
392         if (settings.updateRho()) {
393 
394             // estimate new step size
395             double rhonew = FastMath.min(FastMath.max(rho * FastMath.sqrt((rp * maxDual) / (rd * maxPrimal)),
396                                                       settings.getRhoMin()),
397                                          settings.getRhoMax());
398 
399             if ((rhonew > rho * 5.0) || (rhonew < rho / 5.0)) {
400 
401                 rho = rhonew;
402                 updated = true;
403 
404                 solver.updateSigmaRho(settings.getSigma(), me, rho);
405             }
406         }
407         return updated;
408     }
409 
410 }