1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.hipparchus.optim.nonlinear.vector.constrained;
18
19 import java.util.ArrayList;
20 import java.util.List;
21
22 import org.hipparchus.exception.LocalizedCoreFormats;
23 import org.hipparchus.exception.MathIllegalArgumentException;
24 import org.hipparchus.linear.Array2DRowRealMatrix;
25 import org.hipparchus.linear.ArrayRealVector;
26 import org.hipparchus.linear.RealMatrix;
27 import org.hipparchus.linear.RealVector;
28 import org.hipparchus.optim.ConvergenceChecker;
29 import org.hipparchus.optim.LocalizedOptimFormats;
30 import org.hipparchus.optim.OptimizationData;
31 import org.hipparchus.optim.nonlinear.scalar.ObjectiveFunction;
32 import org.hipparchus.util.FastMath;
33 import org.hipparchus.util.MathUtils;
34
35
36
37
38
39
40
41
42
43
44
45
46
47 public class ADMMQPOptimizer extends QPOptimizer {
48
49
50 private ADMMQPOption settings;
51
52
53 private LinearEqualityConstraint eqConstraint;
54
55
56 private LinearInequalityConstraint iqConstraint;
57
58
59 private LinearBoundedConstraint bqConstraint;
60
61
62 private QuadraticFunction function;
63
64
65 private final ADMMQPKKT solver;
66
67
68 private ADMMQPConvergenceChecker checker;
69
70
71 private boolean converged;
72
73
74 private double rho;
75
76
77
78
79
80
81 public ADMMQPOptimizer() {
82 settings = new ADMMQPOption();
83 solver = new ADMMQPKKT();
84 converged = false;
85 rho = 0.1;
86 }
87
88
89 @Override
90 public ConvergenceChecker<LagrangeSolution> getConvergenceChecker() {
91 return checker;
92 }
93
94
95 @Override
96 public LagrangeSolution optimize(OptimizationData... optData) {
97 return super.optimize(optData);
98 }
99
100
101 @Override
102 protected void parseOptimizationData(OptimizationData... optData) {
103 super.parseOptimizationData(optData);
104 for (OptimizationData data: optData) {
105
106 if (data instanceof ObjectiveFunction) {
107 function = (QuadraticFunction) ((ObjectiveFunction) data).getObjectiveFunction();
108 continue;
109 }
110
111 if (data instanceof LinearEqualityConstraint) {
112 eqConstraint = (LinearEqualityConstraint) data;
113 continue;
114 }
115 if (data instanceof LinearInequalityConstraint) {
116 iqConstraint = (LinearInequalityConstraint) data;
117 continue;
118 }
119
120 if (data instanceof LinearBoundedConstraint) {
121 bqConstraint = (LinearBoundedConstraint) data;
122 continue;
123 }
124
125 if (data instanceof ADMMQPOption) {
126 settings = (ADMMQPOption) data;
127 }
128
129 }
130
131 int n = function.dim();
132 if (eqConstraint != null) {
133 int nDual = eqConstraint.dimY();
134 if (nDual >= n) {
135 throw new MathIllegalArgumentException(LocalizedOptimFormats.CONSTRAINTS_RANK, nDual, n);
136 }
137 int nTest = eqConstraint.getA().getColumnDimension();
138 if (nDual == 0) {
139 throw new MathIllegalArgumentException(LocalizedCoreFormats.ZERO_NOT_ALLOWED);
140 }
141 MathUtils.checkDimension(nTest, n);
142 }
143
144 }
145
146
147 @Override
148 public LagrangeSolution doOptimize() {
149 final int n = function.dim();
150 int me = 0;
151 int mi = 0;
152 int mb = 0;
153 int rhoUpdateCount = 0;
154
155
156
157
158
159 RealMatrix H = function.getP();
160
161 RealVector q = function.getQ();
162
163
164
165 if (eqConstraint != null) {
166 me = eqConstraint.dimY();
167 }
168
169 if (iqConstraint != null) {
170 mi = iqConstraint.dimY();
171 }
172
173 if (bqConstraint != null) {
174 mb = bqConstraint.dimY();
175 }
176
177 RealVector lb = new ArrayRealVector(me + mi + mb);
178 RealVector ub = new ArrayRealVector(me + mi + mb);
179
180
181 RealMatrix A = new Array2DRowRealMatrix(me + mi + mb, n);
182 if (eqConstraint != null) {
183 A.setSubMatrix(eqConstraint.jacobian(null).getData(), 0, 0);
184 lb.setSubVector(0,eqConstraint.getLowerBound());
185 ub.setSubVector(0,eqConstraint.getUpperBound());
186 }
187 if (iqConstraint != null) {
188 A.setSubMatrix(iqConstraint.jacobian(null).getData(), me, 0);
189 ub.setSubVector(me,iqConstraint.getUpperBound());
190 lb.setSubVector(me,iqConstraint.getLowerBound());
191 }
192
193 if (mb > 0) {
194 A.setSubMatrix(bqConstraint.jacobian(null).getData(), me + mi, 0);
195 ub.setSubVector(me + mi,bqConstraint.getUpperBound());
196 lb.setSubVector(me + mi,bqConstraint.getLowerBound());
197 }
198
199 checker = new ADMMQPConvergenceChecker(H, A, q, settings.getEps(), settings.getEps());
200
201
202 RealMatrix Hw = H.copy();
203 RealMatrix Aw = A.copy();
204 RealVector qw = q.copy();
205 RealVector ubw = ub.copy();
206 RealVector lbw = lb.copy();
207 RealVector x;
208 if (getStartPoint() != null) {
209 x = new ArrayRealVector(getStartPoint());
210 } else {
211 x = new ArrayRealVector(function.dim());
212 }
213
214 ADMMQPModifiedRuizEquilibrium dec = new ADMMQPModifiedRuizEquilibrium(H, A,q);
215
216 if (settings.isScaling()) {
217
218 dec.normalize(settings.getEps(), settings.getScaleMaxIteration());
219 Hw = dec.getScaledH();
220 Aw = dec.getScaledA();
221 qw = dec.getScaledQ();
222 lbw = dec.getScaledLUb(lb);
223 ubw = dec.getScaledLUb(ub);
224
225 x = dec.scaleX(x.copy());
226
227 }
228
229 final ADMMQPConvergenceChecker checkerRho = new ADMMQPConvergenceChecker(Hw, Aw, qw, settings.getEps(), settings.getEps());
230
231
232 RealVector z = Aw.operate(x);
233 RealVector y = new ArrayRealVector(me + mi + mb);
234
235 solver.initialize(Hw, Aw, qw, me, lbw, ubw, rho, settings.getSigma(), settings.getAlpha());
236 RealVector xstar = null;
237 RealVector ystar = null;
238 RealVector zstar;
239
240 while (iterations.getCount() <= iterations.getMaximalCount()) {
241 ADMMQPSolution sol = solver.iterate(x, y, z);
242 x = sol.getX();
243 y = sol.getLambda();
244 z = sol.getZ();
245
246 if (rhoUpdateCount < settings.getMaxRhoIteration()) {
247 double rp = checkerRho.residualPrime(x, z);
248 double rd = checkerRho.residualDual(x, y);
249 double maxP = checkerRho.maxPrimal(x, z);
250 double maxD = checkerRho.maxDual(x, y);
251 boolean updated = manageRho(me, rp, rd, maxP, maxD);
252
253 if (updated) {
254 ++rhoUpdateCount;
255 }
256 }
257
258
259 if (settings.isScaling()) {
260
261 xstar = dec.unscaleX(x);
262 ystar = dec.unscaleY(y);
263 zstar = dec.unscaleZ(z);
264
265 } else {
266
267 xstar = x.copy();
268 ystar = y.copy();
269 zstar = z.copy();
270
271 }
272
273 double rp = checker.residualPrime(xstar, zstar);
274 double rd = checker.residualDual(xstar, ystar);
275 double maxPrimal = checker.maxPrimal(xstar, zstar);
276 double maxDual = checker.maxDual(xstar, ystar);
277
278 if (checker.converged(rp, rd, maxPrimal, maxDual)) {
279 converged = true;
280 break;
281 }
282 iterations.increment();
283
284 }
285
286
287 if (settings.isPolishing()) {
288 ADMMQPSolution finalSol = polish(Hw, Aw, qw, lbw, ubw, x, y, z);
289 if (settings.isScaling()) {
290 xstar = dec.unscaleX(finalSol.getX());
291 ystar = dec.unscaleY(finalSol.getLambda());
292 } else {
293 xstar = finalSol.getX();
294 ystar = finalSol.getLambda();
295 }
296 }
297 for (int i = 0; i < me + mi; i++) {
298 ystar.setEntry(i, -ystar.getEntry(i));
299 }
300
301 return new LagrangeSolution(xstar, ystar, function.value(xstar));
302
303 }
304
305
306
307
308 public boolean isConverged() {
309 return converged;
310 }
311
312
313
314
315
316
317
318
319
320
321
322
323 private ADMMQPSolution polish(RealMatrix H, RealMatrix A, RealVector q, RealVector lb, RealVector ub,
324 RealVector x, RealVector y, RealVector z) {
325
326 List<double[]> Aentry = new ArrayList<>();
327 List<Double> lubEntry = new ArrayList<>();
328 List<Double> yEntry = new ArrayList<>();
329
330
331 for (int j = 0; j < A.getRowDimension(); j++) {
332 if (z.getEntry(j) - lb.getEntry(j) < -y.getEntry(j)) {
333
334 Aentry.add(A.getRow(j));
335 lubEntry.add(lb.getEntry(j));
336 yEntry.add(y.getEntry(j));
337
338 }
339 }
340
341 for (int j = 0; j < A.getRowDimension(); j++) {
342 if (-z.getEntry(j) + ub.getEntry(j) < y.getEntry(j)) {
343
344 Aentry.add(A.getRow(j));
345 lubEntry.add(ub.getEntry(j));
346 yEntry.add(y.getEntry(j));
347
348 }
349
350 }
351 RealMatrix Aactive;
352 RealVector lub;
353
354 RealVector ystar;
355 RealVector xstar = x.copy();
356
357 if (!Aentry.isEmpty()) {
358
359 Aactive = new Array2DRowRealMatrix(Aentry.toArray(new double[0][]));
360 lub = new ArrayRealVector(lubEntry.toArray(new Double[0]));
361 ystar = new ArrayRealVector(yEntry.toArray(new Double[0]));
362 solver.initialize(H, Aactive, q, 0, lub, lub,
363 settings.getSigma(), settings.getSigma(), settings.getAlpha());
364
365 for (int i = 0; i < settings.getPolishIteration(); i++) {
366 RealVector kttx = (H.operate(xstar)).add(Aactive.transpose().operate(ystar));
367 RealVector ktty = Aactive.operate(xstar);
368 RealVector b1 = q.mapMultiply(-1.0).subtract(kttx);
369 RealVector b2 = lub.mapMultiply(1.0).subtract(ktty);
370 ADMMQPSolution dxz = solver.solve(b1,b2);
371 xstar = xstar.add(dxz.getX());
372 ystar = ystar.add(dxz.getV());
373 }
374
375 return new ADMMQPSolution(xstar, null, y, A.operate(xstar));
376
377 } else {
378 return new ADMMQPSolution(x, null, y, z);
379 }
380 }
381
382
383
384
385
386
387
388
389
390 private boolean manageRho(int me, double rp, double rd, double maxPrimal, double maxDual) {
391 boolean updated = false;
392 if (settings.updateRho()) {
393
394
395 double rhonew = FastMath.min(FastMath.max(rho * FastMath.sqrt((rp * maxDual) / (rd * maxPrimal)),
396 settings.getRhoMin()),
397 settings.getRhoMax());
398
399 if ((rhonew > rho * 5.0) || (rhonew < rho / 5.0)) {
400
401 rho = rhonew;
402 updated = true;
403
404 solver.updateSigmaRho(settings.getSigma(), me, rho);
405 }
406 }
407 return updated;
408 }
409
410 }