1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * https://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 /*
19 * This is not the original file distributed by the Apache Software Foundation
20 * It has been modified by the Hipparchus project
21 */
22
23 package org.hipparchus.linear;
24
25 import java.util.Arrays;
26 import java.util.function.Predicate;
27
28 import org.hipparchus.CalculusFieldElement;
29 import org.hipparchus.FieldElement;
30 import org.hipparchus.exception.LocalizedCoreFormats;
31 import org.hipparchus.exception.MathIllegalArgumentException;
32 import org.hipparchus.util.FastMath;
33 import org.hipparchus.util.MathArrays;
34
35
36 /**
37 * Calculates the QR-decomposition of a field matrix.
38 * <p>The QR-decomposition of a matrix A consists of two matrices Q and R
39 * that satisfy: A = QR, Q is orthogonal (Q<sup>T</sup>Q = I), and R is
40 * upper triangular. If A is m×n, Q is m×m and R m×n.</p>
41 * <p>This class compute the decomposition using Householder reflectors.</p>
42 * <p>For efficiency purposes, the decomposition in packed form is transposed.
43 * This allows inner loop to iterate inside rows, which is much more cache-efficient
44 * in Java.</p>
45 * <p>This class is based on the class {@link QRDecomposition}.</p>
46 *
47 * @param <T> type of the underlying field elements
48 * @see <a href="http://mathworld.wolfram.com/QRDecomposition.html">MathWorld</a>
49 * @see <a href="http://en.wikipedia.org/wiki/QR_decomposition">Wikipedia</a>
50 *
51 */
52 public class FieldQRDecomposition<T extends CalculusFieldElement<T>> {
53 /**
54 * A packed TRANSPOSED representation of the QR decomposition.
55 * <p>The elements BELOW the diagonal are the elements of the UPPER triangular
56 * matrix R, and the rows ABOVE the diagonal are the Householder reflector vectors
57 * from which an explicit form of Q can be recomputed if desired.</p>
58 */
59 private T[][] qrt;
60 /** The diagonal elements of R. */
61 private T[] rDiag;
62 /** Cached value of Q. */
63 private FieldMatrix<T> cachedQ;
64 /** Cached value of QT. */
65 private FieldMatrix<T> cachedQT;
66 /** Cached value of R. */
67 private FieldMatrix<T> cachedR;
68 /** Cached value of H. */
69 private FieldMatrix<T> cachedH;
70 /** Singularity threshold. */
71 private final T threshold;
72 /** checker for zero. */
73 private final Predicate<T> zeroChecker;
74
75 /**
76 * Calculates the QR-decomposition of the given matrix.
77 * The singularity threshold defaults to zero.
78 *
79 * @param matrix The matrix to decompose.
80 *
81 * @see #FieldQRDecomposition(FieldMatrix, CalculusFieldElement)
82 */
83 public FieldQRDecomposition(FieldMatrix<T> matrix) {
84 this(matrix, matrix.getField().getZero());
85 }
86
87 /**
88 * Calculates the QR-decomposition of the given matrix.
89 *
90 * @param matrix The matrix to decompose.
91 * @param threshold Singularity threshold.
92 */
93 public FieldQRDecomposition(FieldMatrix<T> matrix, T threshold) {
94 this(matrix, threshold, FieldElement::isZero);
95 }
96
97 /**
98 * Calculates the QR-decomposition of the given matrix.
99 *
100 * @param matrix The matrix to decompose.
101 * @param threshold Singularity threshold.
102 * @param zeroChecker checker for zero
103 */
104 public FieldQRDecomposition(FieldMatrix<T> matrix, T threshold, Predicate<T> zeroChecker) {
105 this.threshold = threshold;
106 this.zeroChecker = zeroChecker;
107
108 final int m = matrix.getRowDimension();
109 final int n = matrix.getColumnDimension();
110 qrt = matrix.transpose().getData();
111 rDiag = MathArrays.buildArray(threshold.getField(),FastMath.min(m, n));
112 cachedQ = null;
113 cachedQT = null;
114 cachedR = null;
115 cachedH = null;
116
117 decompose(qrt);
118
119 }
120
121 /** Decompose matrix.
122 * @param matrix transposed matrix
123 */
124 protected void decompose(T[][] matrix) {
125 for (int minor = 0; minor < FastMath.min(matrix.length, matrix[0].length); minor++) {
126 performHouseholderReflection(minor, matrix);
127 }
128 }
129
130 /** Perform Householder reflection for a minor A(minor, minor) of A.
131 * @param minor minor index
132 * @param matrix transposed matrix
133 */
134 protected void performHouseholderReflection(int minor, T[][] matrix) {
135
136 final T[] qrtMinor = matrix[minor];
137 final T zero = threshold.getField().getZero();
138 /*
139 * Let x be the first column of the minor, and a^2 = |x|^2.
140 * x will be in the positions qr[minor][minor] through qr[m][minor].
141 * The first column of the transformed minor will be (a,0,0,..)'
142 * The sign of a is chosen to be opposite to the sign of the first
143 * component of x. Let's find a:
144 */
145 T xNormSqr = zero;
146 for (int row = minor; row < qrtMinor.length; row++) {
147 final T c = qrtMinor[row];
148 xNormSqr = xNormSqr.add(c.square());
149 }
150 final T a = (qrtMinor[minor].getReal() > 0) ? xNormSqr.sqrt().negate() : xNormSqr.sqrt();
151 rDiag[minor] = a;
152
153 if (!zeroChecker.test(a)) {
154
155 /*
156 * Calculate the normalized reflection vector v and transform
157 * the first column. We know the norm of v beforehand: v = x-ae
158 * so |v|^2 = <x-ae,x-ae> = <x,x>-2a<x,e>+a^2<e,e> =
159 * a^2+a^2-2a<x,e> = 2a*(a - <x,e>).
160 * Here <x, e> is now qr[minor][minor].
161 * v = x-ae is stored in the column at qr:
162 */
163 qrtMinor[minor] = qrtMinor[minor].subtract(a); // now |v|^2 = -2a*(qr[minor][minor])
164
165 /*
166 * Transform the rest of the columns of the minor:
167 * They will be transformed by the matrix H = I-2vv'/|v|^2.
168 * If x is a column vector of the minor, then
169 * Hx = (I-2vv'/|v|^2)x = x-2vv'x/|v|^2 = x - 2<x,v>/|v|^2 v.
170 * Therefore the transformation is easily calculated by
171 * subtracting the column vector (2<x,v>/|v|^2)v from x.
172 *
173 * Let 2<x,v>/|v|^2 = alpha. From above we have
174 * |v|^2 = -2a*(qr[minor][minor]), so
175 * alpha = -<x,v>/(a*qr[minor][minor])
176 */
177 for (int col = minor+1; col < matrix.length; col++) {
178 final T[] qrtCol = matrix[col];
179 T alpha = zero;
180 for (int row = minor; row < qrtCol.length; row++) {
181 alpha = alpha.subtract(qrtCol[row].multiply(qrtMinor[row]));
182 }
183 alpha = alpha.divide(a.multiply(qrtMinor[minor]));
184
185 // Subtract the column vector alpha*v from x.
186 for (int row = minor; row < qrtCol.length; row++) {
187 qrtCol[row] = qrtCol[row].subtract(alpha.multiply(qrtMinor[row]));
188 }
189 }
190 }
191 }
192
193
194 /**
195 * Returns the matrix R of the decomposition.
196 * <p>R is an upper-triangular matrix</p>
197 * @return the R matrix
198 */
199 public FieldMatrix<T> getR() {
200
201 if (cachedR == null) {
202
203 // R is supposed to be m x n
204 final int n = qrt.length;
205 final int m = qrt[0].length;
206 T[][] ra = MathArrays.buildArray(threshold.getField(), m, n);
207 // copy the diagonal from rDiag and the upper triangle of qr
208 for (int row = FastMath.min(m, n) - 1; row >= 0; row--) {
209 ra[row][row] = rDiag[row];
210 for (int col = row + 1; col < n; col++) {
211 ra[row][col] = qrt[col][row];
212 }
213 }
214 cachedR = MatrixUtils.createFieldMatrix(ra);
215 }
216
217 // return the cached matrix
218 return cachedR;
219 }
220
221 /**
222 * Returns the matrix Q of the decomposition.
223 * <p>Q is an orthogonal matrix</p>
224 * @return the Q matrix
225 */
226 public FieldMatrix<T> getQ() {
227 if (cachedQ == null) {
228 cachedQ = getQT().transpose();
229 }
230 return cachedQ;
231 }
232
233 /**
234 * Returns the transpose of the matrix Q of the decomposition.
235 * <p>Q is an orthogonal matrix</p>
236 * @return the transpose of the Q matrix, Q<sup>T</sup>
237 */
238 public FieldMatrix<T> getQT() {
239 if (cachedQT == null) {
240
241 // QT is supposed to be m x m
242 final int n = qrt.length;
243 final int m = qrt[0].length;
244 T[][] qta = MathArrays.buildArray(threshold.getField(), m, m);
245
246 /*
247 * Q = Q1 Q2 ... Q_m, so Q is formed by first constructing Q_m and then
248 * applying the Householder transformations Q_(m-1),Q_(m-2),...,Q1 in
249 * succession to the result
250 */
251 for (int minor = m - 1; minor >= FastMath.min(m, n); minor--) {
252 qta[minor][minor] = threshold.getField().getOne();
253 }
254
255 for (int minor = FastMath.min(m, n)-1; minor >= 0; minor--){
256 final T[] qrtMinor = qrt[minor];
257 qta[minor][minor] = threshold.getField().getOne();
258 if (!qrtMinor[minor].isZero()) {
259 for (int col = minor; col < m; col++) {
260 T alpha = threshold.getField().getZero();
261 for (int row = minor; row < m; row++) {
262 alpha = alpha.subtract(qta[col][row].multiply(qrtMinor[row]));
263 }
264 alpha = alpha.divide(rDiag[minor].multiply(qrtMinor[minor]));
265
266 for (int row = minor; row < m; row++) {
267 qta[col][row] = qta[col][row].add(alpha.negate().multiply(qrtMinor[row]));
268 }
269 }
270 }
271 }
272 cachedQT = MatrixUtils.createFieldMatrix(qta);
273 }
274
275 // return the cached matrix
276 return cachedQT;
277 }
278
279 /**
280 * Returns the Householder reflector vectors.
281 * <p>H is a lower trapezoidal matrix whose columns represent
282 * each successive Householder reflector vector. This matrix is used
283 * to compute Q.</p>
284 * @return a matrix containing the Householder reflector vectors
285 */
286 public FieldMatrix<T> getH() {
287 if (cachedH == null) {
288
289 final int n = qrt.length;
290 final int m = qrt[0].length;
291 T[][] ha = MathArrays.buildArray(threshold.getField(), m, n);
292 for (int i = 0; i < m; ++i) {
293 for (int j = 0; j < FastMath.min(i + 1, n); ++j) {
294 ha[i][j] = qrt[j][i].divide(rDiag[j].negate());
295 }
296 }
297 cachedH = MatrixUtils.createFieldMatrix(ha);
298 }
299
300 // return the cached matrix
301 return cachedH;
302 }
303
304 /**
305 * Get a solver for finding the A × X = B solution in least square sense.
306 * <p>
307 * Least Square sense means a solver can be computed for an overdetermined system,
308 * (i.e. a system with more equations than unknowns, which corresponds to a tall A
309 * matrix with more rows than columns). In any case, if the matrix is singular
310 * within the tolerance set at {@link #FieldQRDecomposition(FieldMatrix,
311 * CalculusFieldElement) construction}, an error will be triggered when
312 * the {@link DecompositionSolver#solve(RealVector) solve} method will be called.
313 * </p>
314 * @return a solver
315 */
316 public FieldDecompositionSolver<T> getSolver() {
317 return new FieldSolver();
318 }
319
320 /**
321 * Specialized solver.
322 */
323 private class FieldSolver implements FieldDecompositionSolver<T>{
324
325 /** {@inheritDoc} */
326 @Override
327 public boolean isNonSingular() {
328 return !checkSingular(rDiag, threshold, false);
329 }
330
331 /** {@inheritDoc} */
332 @Override
333 public FieldVector<T> solve(FieldVector<T> b) {
334 final int n = qrt.length;
335 final int m = qrt[0].length;
336 if (b.getDimension() != m) {
337 throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
338 b.getDimension(), m);
339 }
340 checkSingular(rDiag, threshold, true);
341
342 final T[] x =MathArrays.buildArray(threshold.getField(),n);
343 final T[] y = b.toArray();
344
345 // apply Householder transforms to solve Q.y = b
346 for (int minor = 0; minor < FastMath.min(m, n); minor++) {
347
348 final T[] qrtMinor = qrt[minor];
349 T dotProduct = threshold.getField().getZero();
350 for (int row = minor; row < m; row++) {
351 dotProduct = dotProduct.add(y[row].multiply(qrtMinor[row]));
352 }
353 dotProduct = dotProduct.divide(rDiag[minor].multiply(qrtMinor[minor]));
354
355 for (int row = minor; row < m; row++) {
356 y[row] = y[row].add(dotProduct.multiply(qrtMinor[row]));
357 }
358 }
359
360 // solve triangular system R.x = y
361 for (int row = rDiag.length - 1; row >= 0; --row) {
362 y[row] = y[row].divide(rDiag[row]);
363 final T yRow = y[row];
364 final T[] qrtRow = qrt[row];
365 x[row] = yRow;
366 for (int i = 0; i < row; i++) {
367 y[i] = y[i].subtract(yRow.multiply(qrtRow[i]));
368 }
369 }
370
371 return new ArrayFieldVector<>(x, false);
372 }
373
374 /** {@inheritDoc} */
375 @Override
376 public FieldMatrix<T> solve(FieldMatrix<T> b) {
377 final int n = qrt.length;
378 final int m = qrt[0].length;
379 if (b.getRowDimension() != m) {
380 throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
381 b.getRowDimension(), m);
382 }
383 checkSingular(rDiag, threshold, true);
384
385 final int columns = b.getColumnDimension();
386 final int blockSize = BlockFieldMatrix.BLOCK_SIZE;
387 final int cBlocks = (columns + blockSize - 1) / blockSize;
388 final T[][] xBlocks = BlockFieldMatrix.createBlocksLayout(threshold.getField(),n, columns);
389 final T[][] y = MathArrays.buildArray(threshold.getField(), b.getRowDimension(), blockSize);
390 final T[] alpha = MathArrays.buildArray(threshold.getField(), blockSize);
391
392 for (int kBlock = 0; kBlock < cBlocks; ++kBlock) {
393 final int kStart = kBlock * blockSize;
394 final int kEnd = FastMath.min(kStart + blockSize, columns);
395 final int kWidth = kEnd - kStart;
396
397 // get the right hand side vector
398 b.copySubMatrix(0, m - 1, kStart, kEnd - 1, y);
399
400 // apply Householder transforms to solve Q.y = b
401 for (int minor = 0; minor < FastMath.min(m, n); minor++) {
402 final T[] qrtMinor = qrt[minor];
403 final T factor = rDiag[minor].multiply(qrtMinor[minor]).reciprocal();
404
405 Arrays.fill(alpha, 0, kWidth, threshold.getField().getZero());
406 for (int row = minor; row < m; ++row) {
407 final T d = qrtMinor[row];
408 final T[] yRow = y[row];
409 for (int k = 0; k < kWidth; ++k) {
410 alpha[k] = alpha[k].add(d.multiply(yRow[k]));
411 }
412 }
413
414 for (int k = 0; k < kWidth; ++k) {
415 alpha[k] = alpha[k].multiply(factor);
416 }
417
418 for (int row = minor; row < m; ++row) {
419 final T d = qrtMinor[row];
420 final T[] yRow = y[row];
421 for (int k = 0; k < kWidth; ++k) {
422 yRow[k] = yRow[k].add(alpha[k].multiply(d));
423 }
424 }
425 }
426
427 // solve triangular system R.x = y
428 for (int j = rDiag.length - 1; j >= 0; --j) {
429 final int jBlock = j / blockSize;
430 final int jStart = jBlock * blockSize;
431 final T factor = rDiag[j].reciprocal();
432 final T[] yJ = y[j];
433 final T[] xBlock = xBlocks[jBlock * cBlocks + kBlock];
434 int index = (j - jStart) * kWidth;
435 for (int k = 0; k < kWidth; ++k) {
436 yJ[k] =yJ[k].multiply(factor);
437 xBlock[index++] = yJ[k];
438 }
439
440 final T[] qrtJ = qrt[j];
441 for (int i = 0; i < j; ++i) {
442 final T rIJ = qrtJ[i];
443 final T[] yI = y[i];
444 for (int k = 0; k < kWidth; ++k) {
445 yI[k] = yI[k].subtract(yJ[k].multiply(rIJ));
446 }
447 }
448 }
449 }
450
451 return new BlockFieldMatrix<>(n, columns, xBlocks, false);
452 }
453
454 /**
455 * {@inheritDoc}
456 * @throws MathIllegalArgumentException if the decomposed matrix is singular.
457 */
458 @Override
459 public FieldMatrix<T> getInverse() {
460 return solve(MatrixUtils.createFieldIdentityMatrix(threshold.getField(), qrt[0].length));
461 }
462
463 /**
464 * Check singularity.
465 *
466 * @param diag Diagonal elements of the R matrix.
467 * @param min Singularity threshold.
468 * @param raise Whether to raise a {@link MathIllegalArgumentException}
469 * if any element of the diagonal fails the check.
470 * @return {@code true} if any element of the diagonal is smaller
471 * or equal to {@code min}.
472 * @throws MathIllegalArgumentException if the matrix is singular and
473 * {@code raise} is {@code true}.
474 */
475 private boolean checkSingular(T[] diag,
476 T min,
477 boolean raise) {
478 for (final T d : diag) {
479 if (FastMath.abs(d.getReal()) <= min.getReal()) {
480 if (raise) {
481 throw new MathIllegalArgumentException(LocalizedCoreFormats.SINGULAR_MATRIX);
482 } else {
483 return true;
484 }
485 }
486 }
487 return false;
488 }
489
490 /** {@inheritDoc} */
491 @Override
492 public int getRowDimension() {
493 return qrt[0].length;
494 }
495
496 /** {@inheritDoc} */
497 @Override
498 public int getColumnDimension() {
499 return qrt.length;
500 }
501
502 }
503 }