1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.hipparchus.optim.nonlinear.vector.constrained;
18
19
20 import org.hipparchus.linear.ArrayRealVector;
21 import org.hipparchus.linear.DecompositionSolver;
22 import org.hipparchus.linear.EigenDecompositionSymmetric;
23 import org.hipparchus.linear.MatrixUtils;
24 import org.hipparchus.linear.RealMatrix;
25 import org.hipparchus.linear.RealVector;
26 import org.hipparchus.util.FastMath;
27
28
29
30
31 public class ADMMQPKKT implements KarushKuhnTuckerSolver<ADMMQPSolution> {
32
33
34 private RealMatrix H;
35
36
37 private RealVector q;
38
39
40 private RealMatrix A;
41
42
43 private double sigma;
44
45
46 private RealMatrix R;
47
48
49 private RealMatrix Rinv;
50
51
52 private RealVector lb;
53
54
55 private RealVector ub;
56
57
58 private double alpha;
59
60
61 private RealMatrix M;
62
63
64 private DecompositionSolver dsX;
65
66
67
68
69
70
71
72
73 ADMMQPKKT() {
74
75 }
76
77
78 @Override
79 public ADMMQPSolution solve(RealVector b1, final RealVector b2) {
80 RealVector z = dsX.solve(new ArrayRealVector((ArrayRealVector) b1,b2));
81 return new ADMMQPSolution(z.getSubVector(0,b1.getDimension()), z.getSubVector(b1.getDimension(), b2.getDimension()));
82 }
83
84
85
86
87
88
89 public void updateSigmaRho(double newSigma, int me, double rho) {
90 this.sigma = newSigma;
91 this.H = H.add(MatrixUtils.createRealIdentityMatrix(H.getColumnDimension()).scalarMultiply(newSigma));
92 createPenaltyMatrix(me, rho);
93 M = MatrixUtils.createRealMatrix(H.getRowDimension() + A.getRowDimension(),
94 H.getRowDimension() + A.getRowDimension());
95 M.setSubMatrix(H.getData(), 0,0);
96 M.setSubMatrix(A.getData(), H.getRowDimension(),0);
97 M.setSubMatrix(A.transpose().getData(), 0, H.getRowDimension());
98 M.setSubMatrix(Rinv.scalarMultiply(-1.0).getData(), H.getRowDimension(),H.getRowDimension());
99 dsX = new EigenDecompositionSymmetric(M).getSolver();
100 }
101
102
103
104
105
106
107
108
109
110
111
112
113 public void initialize(RealMatrix newH, RealMatrix newA, RealVector newQ,
114 int me, RealVector newLb, RealVector newUb,
115 double rho, double newSigma, double newAlpha) {
116 this.lb = newLb;
117 this.ub = newUb;
118 this.alpha = newAlpha;
119 this.sigma = newSigma;
120 this.H = newH.add(MatrixUtils.createRealIdentityMatrix(newH.getColumnDimension()).scalarMultiply(newSigma));
121 this.A = newA.copy();
122 this.q = newQ.copy();
123 createPenaltyMatrix(me, rho);
124
125 M = MatrixUtils.createRealMatrix(newH.getRowDimension() + newA.getRowDimension(),
126 newH.getRowDimension() + newA.getRowDimension());
127 M.setSubMatrix(newH.getData(),0,0);
128 M.setSubMatrix(newA.getData(),newH.getRowDimension(),0);
129 M.setSubMatrix(newA.transpose().getData(),0,newH.getRowDimension());
130 M.setSubMatrix(Rinv.scalarMultiply(-1.0).getData(),newH.getRowDimension(),newH.getRowDimension());
131 dsX = new EigenDecompositionSymmetric(M).getSolver();
132 }
133
134 private void createPenaltyMatrix(int me, double rho) {
135 this.R = MatrixUtils.createRealIdentityMatrix(A.getRowDimension());
136
137 for (int i = 0; i < R.getRowDimension(); i++) {
138 if (i < me) {
139 R.setEntry(i, i, rho * 1000.0);
140
141 } else {
142 R.setEntry(i, i, rho);
143
144 }
145 }
146 this.Rinv = MatrixUtils.inverse(R);
147 }
148
149
150 @Override
151 public ADMMQPSolution iterate(RealVector... previousSol) {
152 double onealfa = 1.0 - alpha;
153
154 RealVector xold = previousSol[0].copy();
155 RealVector yold = previousSol[1].copy();
156 RealVector zold = previousSol[2].copy();
157
158
159 RealVector b1 = previousSol[0].mapMultiply(sigma).subtract(q);
160 RealVector b2 = previousSol[2].subtract(Rinv.operate(previousSol[1]));
161
162
163 ADMMQPSolution sol = solve(b1, b2);
164 RealVector xtilde = sol.getX();
165 RealVector vtilde = sol.getV();
166
167
168 RealVector ztilde = zold.add(Rinv.operate(vtilde.subtract(yold)));
169
170 previousSol[0] = xtilde.mapMultiply(alpha).add(xold.mapMultiply(onealfa));
171
172
173 RealVector zpartial = ztilde.mapMultiply(alpha).add(zold.mapMultiply(onealfa)).add(Rinv.operate(yold));
174
175
176 for (int j = 0; j < previousSol[2].getDimension(); j++) {
177 previousSol[2].setEntry(j, FastMath.min(FastMath.max(zpartial.getEntry(j), lb.getEntry(j)), ub.getEntry(j)));
178 }
179
180
181 RealVector ytilde = ztilde.mapMultiply(alpha).add(zold.mapMultiply(onealfa).subtract(previousSol[2]));
182 previousSol[1] = yold.add(R.operate(ytilde));
183
184 return new ADMMQPSolution(previousSol[0], vtilde, previousSol[1], previousSol[2]);
185 }
186
187 }