1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23 package org.hipparchus.linear;
24
25 import java.util.Arrays;
26
27 import org.hipparchus.exception.LocalizedCoreFormats;
28 import org.hipparchus.exception.MathIllegalArgumentException;
29 import org.hipparchus.util.FastMath;
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55 public class QRDecomposition {
56
57
58
59
60
61
62 private double[][] qrt;
63
64 private double[] rDiag;
65
66 private RealMatrix cachedQ;
67
68 private RealMatrix cachedQT;
69
70 private RealMatrix cachedR;
71
72 private RealMatrix cachedH;
73
74 private final double threshold;
75
76
77
78
79
80
81
82
83
84 public QRDecomposition(RealMatrix matrix) {
85 this(matrix, 0d);
86 }
87
88
89
90
91
92
93
94 public QRDecomposition(RealMatrix matrix,
95 double threshold) {
96 this.threshold = threshold;
97
98 final int m = matrix.getRowDimension();
99 final int n = matrix.getColumnDimension();
100 qrt = matrix.transpose().getData();
101 rDiag = new double[FastMath.min(m, n)];
102 cachedQ = null;
103 cachedQT = null;
104 cachedR = null;
105 cachedH = null;
106
107 decompose(qrt);
108
109 }
110
111
112
113
114 protected void decompose(double[][] matrix) {
115 for (int minor = 0; minor < FastMath.min(matrix.length, matrix[0].length); minor++) {
116 performHouseholderReflection(minor, matrix);
117 }
118 }
119
120
121
122
123
124 protected void performHouseholderReflection(int minor, double[][] matrix) {
125
126 final double[] qrtMinor = matrix[minor];
127
128
129
130
131
132
133
134
135 double xNormSqr = 0;
136 for (int row = minor; row < qrtMinor.length; row++) {
137 final double c = qrtMinor[row];
138 xNormSqr += c * c;
139 }
140 final double a = (qrtMinor[minor] > 0) ? -FastMath.sqrt(xNormSqr) : FastMath.sqrt(xNormSqr);
141 rDiag[minor] = a;
142
143 if (a != 0.0) {
144
145
146
147
148
149
150
151
152
153 qrtMinor[minor] -= a;
154
155
156
157
158
159
160
161
162
163
164
165
166
167 for (int col = minor+1; col < matrix.length; col++) {
168 final double[] qrtCol = matrix[col];
169 double alpha = 0;
170 for (int row = minor; row < qrtCol.length; row++) {
171 alpha -= qrtCol[row] * qrtMinor[row];
172 }
173 alpha /= a * qrtMinor[minor];
174
175
176 for (int row = minor; row < qrtCol.length; row++) {
177 qrtCol[row] -= alpha * qrtMinor[row];
178 }
179 }
180 }
181 }
182
183
184
185
186
187
188
189 public RealMatrix getR() {
190
191 if (cachedR == null) {
192
193
194 final int n = qrt.length;
195 final int m = qrt[0].length;
196 double[][] ra = new double[m][n];
197
198 for (int row = FastMath.min(m, n) - 1; row >= 0; row--) {
199 ra[row][row] = rDiag[row];
200 for (int col = row + 1; col < n; col++) {
201 ra[row][col] = qrt[col][row];
202 }
203 }
204 cachedR = MatrixUtils.createRealMatrix(ra);
205 }
206
207
208 return cachedR;
209 }
210
211
212
213
214
215
216 public RealMatrix getQ() {
217 if (cachedQ == null) {
218 cachedQ = getQT().transpose();
219 }
220 return cachedQ;
221 }
222
223
224
225
226
227
228 public RealMatrix getQT() {
229 if (cachedQT == null) {
230
231
232 final int n = qrt.length;
233 final int m = qrt[0].length;
234 double[][] qta = new double[m][m];
235
236
237
238
239
240
241 for (int minor = m - 1; minor >= FastMath.min(m, n); minor--) {
242 qta[minor][minor] = 1.0d;
243 }
244
245 for (int minor = FastMath.min(m, n)-1; minor >= 0; minor--){
246 final double[] qrtMinor = qrt[minor];
247 qta[minor][minor] = 1.0d;
248 if (qrtMinor[minor] != 0.0) {
249 for (int col = minor; col < m; col++) {
250 double alpha = 0;
251 for (int row = minor; row < m; row++) {
252 alpha -= qta[col][row] * qrtMinor[row];
253 }
254 alpha /= rDiag[minor] * qrtMinor[minor];
255
256 for (int row = minor; row < m; row++) {
257 qta[col][row] += -alpha * qrtMinor[row];
258 }
259 }
260 }
261 }
262 cachedQT = MatrixUtils.createRealMatrix(qta);
263 }
264
265
266 return cachedQT;
267 }
268
269
270
271
272
273
274
275
276 public RealMatrix getH() {
277 if (cachedH == null) {
278
279 final int n = qrt.length;
280 final int m = qrt[0].length;
281 double[][] ha = new double[m][n];
282 for (int i = 0; i < m; ++i) {
283 for (int j = 0; j < FastMath.min(i + 1, n); ++j) {
284 ha[i][j] = qrt[j][i] / -rDiag[j];
285 }
286 }
287 cachedH = MatrixUtils.createRealMatrix(ha);
288 }
289
290
291 return cachedH;
292 }
293
294
295
296
297
298
299
300
301
302
303
304
305
306 public DecompositionSolver getSolver() {
307 return new Solver();
308 }
309
310
311 private class Solver implements DecompositionSolver {
312
313
314 @Override
315 public boolean isNonSingular() {
316 return !checkSingular(rDiag, threshold, false);
317 }
318
319
320 @Override
321 public RealVector solve(RealVector b) {
322 final int n = qrt.length;
323 final int m = qrt[0].length;
324 if (b.getDimension() != m) {
325 throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
326 b.getDimension(), m);
327 }
328 checkSingular(rDiag, threshold, true);
329
330 final double[] x = new double[n];
331 final double[] y = b.toArray();
332
333
334 for (int minor = 0; minor < FastMath.min(m, n); minor++) {
335
336 final double[] qrtMinor = qrt[minor];
337 double dotProduct = 0;
338 for (int row = minor; row < m; row++) {
339 dotProduct += y[row] * qrtMinor[row];
340 }
341 dotProduct /= rDiag[minor] * qrtMinor[minor];
342
343 for (int row = minor; row < m; row++) {
344 y[row] += dotProduct * qrtMinor[row];
345 }
346 }
347
348
349 for (int row = rDiag.length - 1; row >= 0; --row) {
350 y[row] /= rDiag[row];
351 final double yRow = y[row];
352 final double[] qrtRow = qrt[row];
353 x[row] = yRow;
354 for (int i = 0; i < row; i++) {
355 y[i] -= yRow * qrtRow[i];
356 }
357 }
358
359 return new ArrayRealVector(x, false);
360 }
361
362
363 @Override
364 public RealMatrix solve(RealMatrix b) {
365 final int n = qrt.length;
366 final int m = qrt[0].length;
367 if (b.getRowDimension() != m) {
368 throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
369 b.getRowDimension(), m);
370 }
371 checkSingular(rDiag, threshold, true);
372
373 final int columns = b.getColumnDimension();
374 final int blockSize = BlockRealMatrix.BLOCK_SIZE;
375 final int cBlocks = (columns + blockSize - 1) / blockSize;
376 final double[][] xBlocks = BlockRealMatrix.createBlocksLayout(n, columns);
377 final double[][] y = new double[b.getRowDimension()][blockSize];
378 final double[] alpha = new double[blockSize];
379
380 for (int kBlock = 0; kBlock < cBlocks; ++kBlock) {
381 final int kStart = kBlock * blockSize;
382 final int kEnd = FastMath.min(kStart + blockSize, columns);
383 final int kWidth = kEnd - kStart;
384
385
386 b.copySubMatrix(0, m - 1, kStart, kEnd - 1, y);
387
388
389 for (int minor = 0; minor < FastMath.min(m, n); minor++) {
390 final double[] qrtMinor = qrt[minor];
391 final double factor = 1.0 / (rDiag[minor] * qrtMinor[minor]);
392
393 Arrays.fill(alpha, 0, kWidth, 0.0);
394 for (int row = minor; row < m; ++row) {
395 final double d = qrtMinor[row];
396 final double[] yRow = y[row];
397 for (int k = 0; k < kWidth; ++k) {
398 alpha[k] += d * yRow[k];
399 }
400 }
401 for (int k = 0; k < kWidth; ++k) {
402 alpha[k] *= factor;
403 }
404
405 for (int row = minor; row < m; ++row) {
406 final double d = qrtMinor[row];
407 final double[] yRow = y[row];
408 for (int k = 0; k < kWidth; ++k) {
409 yRow[k] += alpha[k] * d;
410 }
411 }
412 }
413
414
415 for (int j = rDiag.length - 1; j >= 0; --j) {
416 final int jBlock = j / blockSize;
417 final int jStart = jBlock * blockSize;
418 final double factor = 1.0 / rDiag[j];
419 final double[] yJ = y[j];
420 final double[] xBlock = xBlocks[jBlock * cBlocks + kBlock];
421 int index = (j - jStart) * kWidth;
422 for (int k = 0; k < kWidth; ++k) {
423 yJ[k] *= factor;
424 xBlock[index++] = yJ[k];
425 }
426
427 final double[] qrtJ = qrt[j];
428 for (int i = 0; i < j; ++i) {
429 final double rIJ = qrtJ[i];
430 final double[] yI = y[i];
431 for (int k = 0; k < kWidth; ++k) {
432 yI[k] -= yJ[k] * rIJ;
433 }
434 }
435 }
436 }
437
438 return new BlockRealMatrix(n, columns, xBlocks, false);
439 }
440
441
442
443
444
445 @Override
446 public RealMatrix getInverse() {
447 return solve(MatrixUtils.createRealIdentityMatrix(qrt[0].length));
448 }
449
450
451
452
453
454
455
456
457
458
459
460
461
462 private boolean checkSingular(double[] diag, double min, boolean raise) {
463 final int len = diag.length;
464 for (int i = 0; i < len; i++) {
465 final double d = diag[i];
466 if (FastMath.abs(d) <= min) {
467 if (raise) {
468 throw new MathIllegalArgumentException(LocalizedCoreFormats.SINGULAR_MATRIX);
469 } else {
470 return true;
471 }
472 }
473 }
474 return false;
475 }
476
477
478 @Override
479 public int getRowDimension() {
480 return qrt[0].length;
481 }
482
483
484 @Override
485 public int getColumnDimension() {
486 return qrt.length;
487 }
488
489 }
490 }