1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23 package org.hipparchus.linear;
24
25 import java.util.Arrays;
26 import java.util.function.Predicate;
27
28 import org.hipparchus.CalculusFieldElement;
29 import org.hipparchus.exception.LocalizedCoreFormats;
30 import org.hipparchus.exception.MathIllegalArgumentException;
31 import org.hipparchus.util.FastMath;
32 import org.hipparchus.util.MathArrays;
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51 public class FieldQRDecomposition<T extends CalculusFieldElement<T>> {
52
53
54
55
56
57
58 private T[][] qrt;
59
60 private T[] rDiag;
61
62 private FieldMatrix<T> cachedQ;
63
64 private FieldMatrix<T> cachedQT;
65
66 private FieldMatrix<T> cachedR;
67
68 private FieldMatrix<T> cachedH;
69
70 private final T threshold;
71
72 private final Predicate<T> zeroChecker;
73
74
75
76
77
78
79
80
81
82 public FieldQRDecomposition(FieldMatrix<T> matrix) {
83 this(matrix, matrix.getField().getZero());
84 }
85
86
87
88
89
90
91
92 public FieldQRDecomposition(FieldMatrix<T> matrix, T threshold) {
93 this(matrix, threshold, e -> e.isZero());
94 }
95
96
97
98
99
100
101
102
103 public FieldQRDecomposition(FieldMatrix<T> matrix, T threshold, Predicate<T> zeroChecker) {
104 this.threshold = threshold;
105 this.zeroChecker = zeroChecker;
106
107 final int m = matrix.getRowDimension();
108 final int n = matrix.getColumnDimension();
109 qrt = matrix.transpose().getData();
110 rDiag = MathArrays.buildArray(threshold.getField(),FastMath.min(m, n));
111 cachedQ = null;
112 cachedQT = null;
113 cachedR = null;
114 cachedH = null;
115
116 decompose(qrt);
117
118 }
119
120
121
122
123 protected void decompose(T[][] matrix) {
124 for (int minor = 0; minor < FastMath.min(matrix.length, matrix[0].length); minor++) {
125 performHouseholderReflection(minor, matrix);
126 }
127 }
128
129
130
131
132
133 protected void performHouseholderReflection(int minor, T[][] matrix) {
134
135 final T[] qrtMinor = matrix[minor];
136 final T zero = threshold.getField().getZero();
137
138
139
140
141
142
143
144 T xNormSqr = zero;
145 for (int row = minor; row < qrtMinor.length; row++) {
146 final T c = qrtMinor[row];
147 xNormSqr = xNormSqr.add(c.square());
148 }
149 final T a = (qrtMinor[minor].getReal() > 0) ? xNormSqr.sqrt().negate() : xNormSqr.sqrt();
150 rDiag[minor] = a;
151
152 if (!zeroChecker.test(a)) {
153
154
155
156
157
158
159
160
161
162 qrtMinor[minor] = qrtMinor[minor].subtract(a);
163
164
165
166
167
168
169
170
171
172
173
174
175
176 for (int col = minor+1; col < matrix.length; col++) {
177 final T[] qrtCol = matrix[col];
178 T alpha = zero;
179 for (int row = minor; row < qrtCol.length; row++) {
180 alpha = alpha.subtract(qrtCol[row].multiply(qrtMinor[row]));
181 }
182 alpha = alpha.divide(a.multiply(qrtMinor[minor]));
183
184
185 for (int row = minor; row < qrtCol.length; row++) {
186 qrtCol[row] = qrtCol[row].subtract(alpha.multiply(qrtMinor[row]));
187 }
188 }
189 }
190 }
191
192
193
194
195
196
197
198 public FieldMatrix<T> getR() {
199
200 if (cachedR == null) {
201
202
203 final int n = qrt.length;
204 final int m = qrt[0].length;
205 T[][] ra = MathArrays.buildArray(threshold.getField(), m, n);
206
207 for (int row = FastMath.min(m, n) - 1; row >= 0; row--) {
208 ra[row][row] = rDiag[row];
209 for (int col = row + 1; col < n; col++) {
210 ra[row][col] = qrt[col][row];
211 }
212 }
213 cachedR = MatrixUtils.createFieldMatrix(ra);
214 }
215
216
217 return cachedR;
218 }
219
220
221
222
223
224
225 public FieldMatrix<T> getQ() {
226 if (cachedQ == null) {
227 cachedQ = getQT().transpose();
228 }
229 return cachedQ;
230 }
231
232
233
234
235
236
237 public FieldMatrix<T> getQT() {
238 if (cachedQT == null) {
239
240
241 final int n = qrt.length;
242 final int m = qrt[0].length;
243 T[][] qta = MathArrays.buildArray(threshold.getField(), m, m);
244
245
246
247
248
249
250 for (int minor = m - 1; minor >= FastMath.min(m, n); minor--) {
251 qta[minor][minor] = threshold.getField().getOne();
252 }
253
254 for (int minor = FastMath.min(m, n)-1; minor >= 0; minor--){
255 final T[] qrtMinor = qrt[minor];
256 qta[minor][minor] = threshold.getField().getOne();
257 if (!qrtMinor[minor].isZero()) {
258 for (int col = minor; col < m; col++) {
259 T alpha = threshold.getField().getZero();
260 for (int row = minor; row < m; row++) {
261 alpha = alpha.subtract(qta[col][row].multiply(qrtMinor[row]));
262 }
263 alpha = alpha.divide(rDiag[minor].multiply(qrtMinor[minor]));
264
265 for (int row = minor; row < m; row++) {
266 qta[col][row] = qta[col][row].add(alpha.negate().multiply(qrtMinor[row]));
267 }
268 }
269 }
270 }
271 cachedQT = MatrixUtils.createFieldMatrix(qta);
272 }
273
274
275 return cachedQT;
276 }
277
278
279
280
281
282
283
284
285 public FieldMatrix<T> getH() {
286 if (cachedH == null) {
287
288 final int n = qrt.length;
289 final int m = qrt[0].length;
290 T[][] ha = MathArrays.buildArray(threshold.getField(), m, n);
291 for (int i = 0; i < m; ++i) {
292 for (int j = 0; j < FastMath.min(i + 1, n); ++j) {
293 ha[i][j] = qrt[j][i].divide(rDiag[j].negate());
294 }
295 }
296 cachedH = MatrixUtils.createFieldMatrix(ha);
297 }
298
299
300 return cachedH;
301 }
302
303
304
305
306
307
308
309
310
311
312
313
314
315 public FieldDecompositionSolver<T> getSolver() {
316 return new FieldSolver();
317 }
318
319
320
321
322 private class FieldSolver implements FieldDecompositionSolver<T>{
323
324
325 @Override
326 public boolean isNonSingular() {
327 return !checkSingular(rDiag, threshold, false);
328 }
329
330
331 @Override
332 public FieldVector<T> solve(FieldVector<T> b) {
333 final int n = qrt.length;
334 final int m = qrt[0].length;
335 if (b.getDimension() != m) {
336 throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
337 b.getDimension(), m);
338 }
339 checkSingular(rDiag, threshold, true);
340
341 final T[] x =MathArrays.buildArray(threshold.getField(),n);
342 final T[] y = b.toArray();
343
344
345 for (int minor = 0; minor < FastMath.min(m, n); minor++) {
346
347 final T[] qrtMinor = qrt[minor];
348 T dotProduct = threshold.getField().getZero();
349 for (int row = minor; row < m; row++) {
350 dotProduct = dotProduct.add(y[row].multiply(qrtMinor[row]));
351 }
352 dotProduct = dotProduct.divide(rDiag[minor].multiply(qrtMinor[minor]));
353
354 for (int row = minor; row < m; row++) {
355 y[row] = y[row].add(dotProduct.multiply(qrtMinor[row]));
356 }
357 }
358
359
360 for (int row = rDiag.length - 1; row >= 0; --row) {
361 y[row] = y[row].divide(rDiag[row]);
362 final T yRow = y[row];
363 final T[] qrtRow = qrt[row];
364 x[row] = yRow;
365 for (int i = 0; i < row; i++) {
366 y[i] = y[i].subtract(yRow.multiply(qrtRow[i]));
367 }
368 }
369
370 return new ArrayFieldVector<T>(x, false);
371 }
372
373
374 @Override
375 public FieldMatrix<T> solve(FieldMatrix<T> b) {
376 final int n = qrt.length;
377 final int m = qrt[0].length;
378 if (b.getRowDimension() != m) {
379 throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
380 b.getRowDimension(), m);
381 }
382 checkSingular(rDiag, threshold, true);
383
384 final int columns = b.getColumnDimension();
385 final int blockSize = BlockFieldMatrix.BLOCK_SIZE;
386 final int cBlocks = (columns + blockSize - 1) / blockSize;
387 final T[][] xBlocks = BlockFieldMatrix.createBlocksLayout(threshold.getField(),n, columns);
388 final T[][] y = MathArrays.buildArray(threshold.getField(), b.getRowDimension(), blockSize);
389 final T[] alpha = MathArrays.buildArray(threshold.getField(), blockSize);
390
391 for (int kBlock = 0; kBlock < cBlocks; ++kBlock) {
392 final int kStart = kBlock * blockSize;
393 final int kEnd = FastMath.min(kStart + blockSize, columns);
394 final int kWidth = kEnd - kStart;
395
396
397 b.copySubMatrix(0, m - 1, kStart, kEnd - 1, y);
398
399
400 for (int minor = 0; minor < FastMath.min(m, n); minor++) {
401 final T[] qrtMinor = qrt[minor];
402 final T factor = rDiag[minor].multiply(qrtMinor[minor]).reciprocal();
403
404 Arrays.fill(alpha, 0, kWidth, threshold.getField().getZero());
405 for (int row = minor; row < m; ++row) {
406 final T d = qrtMinor[row];
407 final T[] yRow = y[row];
408 for (int k = 0; k < kWidth; ++k) {
409 alpha[k] = alpha[k].add(d.multiply(yRow[k]));
410 }
411 }
412
413 for (int k = 0; k < kWidth; ++k) {
414 alpha[k] = alpha[k].multiply(factor);
415 }
416
417 for (int row = minor; row < m; ++row) {
418 final T d = qrtMinor[row];
419 final T[] yRow = y[row];
420 for (int k = 0; k < kWidth; ++k) {
421 yRow[k] = yRow[k].add(alpha[k].multiply(d));
422 }
423 }
424 }
425
426
427 for (int j = rDiag.length - 1; j >= 0; --j) {
428 final int jBlock = j / blockSize;
429 final int jStart = jBlock * blockSize;
430 final T factor = rDiag[j].reciprocal();
431 final T[] yJ = y[j];
432 final T[] xBlock = xBlocks[jBlock * cBlocks + kBlock];
433 int index = (j - jStart) * kWidth;
434 for (int k = 0; k < kWidth; ++k) {
435 yJ[k] =yJ[k].multiply(factor);
436 xBlock[index++] = yJ[k];
437 }
438
439 final T[] qrtJ = qrt[j];
440 for (int i = 0; i < j; ++i) {
441 final T rIJ = qrtJ[i];
442 final T[] yI = y[i];
443 for (int k = 0; k < kWidth; ++k) {
444 yI[k] = yI[k].subtract(yJ[k].multiply(rIJ));
445 }
446 }
447 }
448 }
449
450 return new BlockFieldMatrix<T>(n, columns, xBlocks, false);
451 }
452
453
454
455
456
457 @Override
458 public FieldMatrix<T> getInverse() {
459 return solve(MatrixUtils.createFieldIdentityMatrix(threshold.getField(), qrt[0].length));
460 }
461
462
463
464
465
466
467
468
469
470
471
472
473
474 private boolean checkSingular(T[] diag,
475 T min,
476 boolean raise) {
477 final int len = diag.length;
478 for (int i = 0; i < len; i++) {
479 final T d = diag[i];
480 if (FastMath.abs(d.getReal()) <= min.getReal()) {
481 if (raise) {
482 throw new MathIllegalArgumentException(LocalizedCoreFormats.SINGULAR_MATRIX);
483 } else {
484 return true;
485 }
486 }
487 }
488 return false;
489 }
490
491
492 @Override
493 public int getRowDimension() {
494 return qrt[0].length;
495 }
496
497
498 @Override
499 public int getColumnDimension() {
500 return qrt.length;
501 }
502
503 }
504 }