View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  
23  package org.hipparchus.linear;
24  
25  import java.util.Arrays;
26  import java.util.function.Predicate;
27  
28  import org.hipparchus.CalculusFieldElement;
29  import org.hipparchus.exception.LocalizedCoreFormats;
30  import org.hipparchus.exception.MathIllegalArgumentException;
31  import org.hipparchus.util.FastMath;
32  import org.hipparchus.util.MathArrays;
33  
34  
35  /**
36   * Calculates the QR-decomposition of a field matrix.
37   * <p>The QR-decomposition of a matrix A consists of two matrices Q and R
38   * that satisfy: A = QR, Q is orthogonal (Q<sup>T</sup>Q = I), and R is
39   * upper triangular. If A is m&times;n, Q is m&times;m and R m&times;n.</p>
40   * <p>This class compute the decomposition using Householder reflectors.</p>
41   * <p>For efficiency purposes, the decomposition in packed form is transposed.
42   * This allows inner loop to iterate inside rows, which is much more cache-efficient
43   * in Java.</p>
44   * <p>This class is based on the class {@link QRDecomposition}.</p>
45   *
46   * @param <T> type of the underlying field elements
47   * @see <a href="http://mathworld.wolfram.com/QRDecomposition.html">MathWorld</a>
48   * @see <a href="http://en.wikipedia.org/wiki/QR_decomposition">Wikipedia</a>
49   *
50   */
51  public class FieldQRDecomposition<T extends CalculusFieldElement<T>> {
52      /**
53       * A packed TRANSPOSED representation of the QR decomposition.
54       * <p>The elements BELOW the diagonal are the elements of the UPPER triangular
55       * matrix R, and the rows ABOVE the diagonal are the Householder reflector vectors
56       * from which an explicit form of Q can be recomputed if desired.</p>
57       */
58      private T[][] qrt;
59      /** The diagonal elements of R. */
60      private T[] rDiag;
61      /** Cached value of Q. */
62      private FieldMatrix<T> cachedQ;
63      /** Cached value of QT. */
64      private FieldMatrix<T> cachedQT;
65      /** Cached value of R. */
66      private FieldMatrix<T> cachedR;
67      /** Cached value of H. */
68      private FieldMatrix<T> cachedH;
69      /** Singularity threshold. */
70      private final T threshold;
71      /** checker for zero. */
72      private final Predicate<T> zeroChecker;
73  
74      /**
75       * Calculates the QR-decomposition of the given matrix.
76       * The singularity threshold defaults to zero.
77       *
78       * @param matrix The matrix to decompose.
79       *
80       * @see #FieldQRDecomposition(FieldMatrix, CalculusFieldElement)
81       */
82      public FieldQRDecomposition(FieldMatrix<T> matrix) {
83          this(matrix, matrix.getField().getZero());
84      }
85  
86      /**
87       * Calculates the QR-decomposition of the given matrix.
88       *
89       * @param matrix The matrix to decompose.
90       * @param threshold Singularity threshold.
91       */
92      public FieldQRDecomposition(FieldMatrix<T> matrix, T threshold) {
93          this(matrix, threshold, e -> e.isZero());
94      }
95  
96      /**
97       * Calculates the QR-decomposition of the given matrix.
98       *
99       * @param matrix The matrix to decompose.
100      * @param threshold Singularity threshold.
101      * @param zeroChecker checker for zero
102      */
103     public FieldQRDecomposition(FieldMatrix<T> matrix, T threshold, Predicate<T> zeroChecker) {
104         this.threshold   = threshold;
105         this.zeroChecker = zeroChecker;
106 
107         final int m = matrix.getRowDimension();
108         final int n = matrix.getColumnDimension();
109         qrt = matrix.transpose().getData();
110         rDiag = MathArrays.buildArray(threshold.getField(),FastMath.min(m, n));
111         cachedQ  = null;
112         cachedQT = null;
113         cachedR  = null;
114         cachedH  = null;
115 
116         decompose(qrt);
117 
118     }
119 
120     /** Decompose matrix.
121      * @param matrix transposed matrix
122      */
123     protected void decompose(T[][] matrix) {
124         for (int minor = 0; minor < FastMath.min(matrix.length, matrix[0].length); minor++) {
125             performHouseholderReflection(minor, matrix);
126         }
127     }
128 
129     /** Perform Householder reflection for a minor A(minor, minor) of A.
130      * @param minor minor index
131      * @param matrix transposed matrix
132      */
133     protected void performHouseholderReflection(int minor, T[][] matrix) {
134 
135         final T[] qrtMinor = matrix[minor];
136         final T zero = threshold.getField().getZero();
137         /*
138          * Let x be the first column of the minor, and a^2 = |x|^2.
139          * x will be in the positions qr[minor][minor] through qr[m][minor].
140          * The first column of the transformed minor will be (a,0,0,..)'
141          * The sign of a is chosen to be opposite to the sign of the first
142          * component of x. Let's find a:
143          */
144         T xNormSqr = zero;
145         for (int row = minor; row < qrtMinor.length; row++) {
146             final T c = qrtMinor[row];
147             xNormSqr = xNormSqr.add(c.square());
148         }
149         final T a = (qrtMinor[minor].getReal() > 0) ? xNormSqr.sqrt().negate() : xNormSqr.sqrt();
150         rDiag[minor] = a;
151 
152         if (!zeroChecker.test(a)) {
153 
154             /*
155              * Calculate the normalized reflection vector v and transform
156              * the first column. We know the norm of v beforehand: v = x-ae
157              * so |v|^2 = <x-ae,x-ae> = <x,x>-2a<x,e>+a^2<e,e> =
158              * a^2+a^2-2a<x,e> = 2a*(a - <x,e>).
159              * Here <x, e> is now qr[minor][minor].
160              * v = x-ae is stored in the column at qr:
161              */
162             qrtMinor[minor] = qrtMinor[minor].subtract(a); // now |v|^2 = -2a*(qr[minor][minor])
163 
164             /*
165              * Transform the rest of the columns of the minor:
166              * They will be transformed by the matrix H = I-2vv'/|v|^2.
167              * If x is a column vector of the minor, then
168              * Hx = (I-2vv'/|v|^2)x = x-2vv'x/|v|^2 = x - 2<x,v>/|v|^2 v.
169              * Therefore the transformation is easily calculated by
170              * subtracting the column vector (2<x,v>/|v|^2)v from x.
171              *
172              * Let 2<x,v>/|v|^2 = alpha. From above we have
173              * |v|^2 = -2a*(qr[minor][minor]), so
174              * alpha = -<x,v>/(a*qr[minor][minor])
175              */
176             for (int col = minor+1; col < matrix.length; col++) {
177                 final T[] qrtCol = matrix[col];
178                 T alpha = zero;
179                 for (int row = minor; row < qrtCol.length; row++) {
180                     alpha = alpha.subtract(qrtCol[row].multiply(qrtMinor[row]));
181                 }
182                 alpha = alpha.divide(a.multiply(qrtMinor[minor]));
183 
184                 // Subtract the column vector alpha*v from x.
185                 for (int row = minor; row < qrtCol.length; row++) {
186                     qrtCol[row] = qrtCol[row].subtract(alpha.multiply(qrtMinor[row]));
187                 }
188             }
189         }
190     }
191 
192 
193     /**
194      * Returns the matrix R of the decomposition.
195      * <p>R is an upper-triangular matrix</p>
196      * @return the R matrix
197      */
198     public FieldMatrix<T> getR() {
199 
200         if (cachedR == null) {
201 
202             // R is supposed to be m x n
203             final int n = qrt.length;
204             final int m = qrt[0].length;
205             T[][] ra = MathArrays.buildArray(threshold.getField(), m, n);
206             // copy the diagonal from rDiag and the upper triangle of qr
207             for (int row = FastMath.min(m, n) - 1; row >= 0; row--) {
208                 ra[row][row] = rDiag[row];
209                 for (int col = row + 1; col < n; col++) {
210                     ra[row][col] = qrt[col][row];
211                 }
212             }
213             cachedR = MatrixUtils.createFieldMatrix(ra);
214         }
215 
216         // return the cached matrix
217         return cachedR;
218     }
219 
220     /**
221      * Returns the matrix Q of the decomposition.
222      * <p>Q is an orthogonal matrix</p>
223      * @return the Q matrix
224      */
225     public FieldMatrix<T> getQ() {
226         if (cachedQ == null) {
227             cachedQ = getQT().transpose();
228         }
229         return cachedQ;
230     }
231 
232     /**
233      * Returns the transpose of the matrix Q of the decomposition.
234      * <p>Q is an orthogonal matrix</p>
235      * @return the transpose of the Q matrix, Q<sup>T</sup>
236      */
237     public FieldMatrix<T> getQT() {
238         if (cachedQT == null) {
239 
240             // QT is supposed to be m x m
241             final int n = qrt.length;
242             final int m = qrt[0].length;
243             T[][] qta = MathArrays.buildArray(threshold.getField(), m, m);
244 
245             /*
246              * Q = Q1 Q2 ... Q_m, so Q is formed by first constructing Q_m and then
247              * applying the Householder transformations Q_(m-1),Q_(m-2),...,Q1 in
248              * succession to the result
249              */
250             for (int minor = m - 1; minor >= FastMath.min(m, n); minor--) {
251                 qta[minor][minor] = threshold.getField().getOne();
252             }
253 
254             for (int minor = FastMath.min(m, n)-1; minor >= 0; minor--){
255                 final T[] qrtMinor = qrt[minor];
256                 qta[minor][minor] = threshold.getField().getOne();
257                 if (!qrtMinor[minor].isZero()) {
258                     for (int col = minor; col < m; col++) {
259                         T alpha = threshold.getField().getZero();
260                         for (int row = minor; row < m; row++) {
261                             alpha = alpha.subtract(qta[col][row].multiply(qrtMinor[row]));
262                         }
263                         alpha = alpha.divide(rDiag[minor].multiply(qrtMinor[minor]));
264 
265                         for (int row = minor; row < m; row++) {
266                             qta[col][row] = qta[col][row].add(alpha.negate().multiply(qrtMinor[row]));
267                         }
268                     }
269                 }
270             }
271             cachedQT = MatrixUtils.createFieldMatrix(qta);
272         }
273 
274         // return the cached matrix
275         return cachedQT;
276     }
277 
278     /**
279      * Returns the Householder reflector vectors.
280      * <p>H is a lower trapezoidal matrix whose columns represent
281      * each successive Householder reflector vector. This matrix is used
282      * to compute Q.</p>
283      * @return a matrix containing the Householder reflector vectors
284      */
285     public FieldMatrix<T> getH() {
286         if (cachedH == null) {
287 
288             final int n = qrt.length;
289             final int m = qrt[0].length;
290             T[][] ha = MathArrays.buildArray(threshold.getField(), m, n);
291             for (int i = 0; i < m; ++i) {
292                 for (int j = 0; j < FastMath.min(i + 1, n); ++j) {
293                     ha[i][j] = qrt[j][i].divide(rDiag[j].negate());
294                 }
295             }
296             cachedH = MatrixUtils.createFieldMatrix(ha);
297         }
298 
299         // return the cached matrix
300         return cachedH;
301     }
302 
303     /**
304      * Get a solver for finding the A &times; X = B solution in least square sense.
305      * <p>
306      * Least Square sense means a solver can be computed for an overdetermined system,
307      * (i.e. a system with more equations than unknowns, which corresponds to a tall A
308      * matrix with more rows than columns). In any case, if the matrix is singular
309      * within the tolerance set at {@link #FieldQRDecomposition(FieldMatrix,
310      * CalculusFieldElement) construction}, an error will be triggered when
311      * the {@link DecompositionSolver#solve(RealVector) solve} method will be called.
312      * </p>
313      * @return a solver
314      */
315     public FieldDecompositionSolver<T> getSolver() {
316         return new FieldSolver();
317     }
318 
319     /**
320      * Specialized solver.
321      */
322     private class FieldSolver implements FieldDecompositionSolver<T>{
323 
324         /** {@inheritDoc} */
325         @Override
326         public boolean isNonSingular() {
327             return !checkSingular(rDiag, threshold, false);
328         }
329 
330         /** {@inheritDoc} */
331         @Override
332         public FieldVector<T> solve(FieldVector<T> b) {
333             final int n = qrt.length;
334             final int m = qrt[0].length;
335             if (b.getDimension() != m) {
336                 throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
337                                                        b.getDimension(), m);
338             }
339             checkSingular(rDiag, threshold, true);
340 
341             final T[] x =MathArrays.buildArray(threshold.getField(),n);
342             final T[] y = b.toArray();
343 
344             // apply Householder transforms to solve Q.y = b
345             for (int minor = 0; minor < FastMath.min(m, n); minor++) {
346 
347                 final T[] qrtMinor = qrt[minor];
348                 T dotProduct = threshold.getField().getZero();
349                 for (int row = minor; row < m; row++) {
350                     dotProduct = dotProduct.add(y[row].multiply(qrtMinor[row]));
351                 }
352                 dotProduct =  dotProduct.divide(rDiag[minor].multiply(qrtMinor[minor]));
353 
354                 for (int row = minor; row < m; row++) {
355                     y[row] = y[row].add(dotProduct.multiply(qrtMinor[row]));
356                 }
357             }
358 
359             // solve triangular system R.x = y
360             for (int row = rDiag.length - 1; row >= 0; --row) {
361                 y[row] = y[row].divide(rDiag[row]);
362                 final T yRow = y[row];
363                 final T[] qrtRow = qrt[row];
364                 x[row] = yRow;
365                 for (int i = 0; i < row; i++) {
366                     y[i] = y[i].subtract(yRow.multiply(qrtRow[i]));
367                 }
368             }
369 
370             return new ArrayFieldVector<T>(x, false);
371         }
372 
373         /** {@inheritDoc} */
374         @Override
375         public FieldMatrix<T> solve(FieldMatrix<T> b) {
376             final int n = qrt.length;
377             final int m = qrt[0].length;
378             if (b.getRowDimension() != m) {
379                 throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
380                                                        b.getRowDimension(), m);
381             }
382             checkSingular(rDiag, threshold, true);
383 
384             final int columns        = b.getColumnDimension();
385             final int blockSize      = BlockFieldMatrix.BLOCK_SIZE;
386             final int cBlocks        = (columns + blockSize - 1) / blockSize;
387             final T[][] xBlocks = BlockFieldMatrix.createBlocksLayout(threshold.getField(),n, columns);
388             final T[][] y       = MathArrays.buildArray(threshold.getField(), b.getRowDimension(), blockSize);
389             final T[]   alpha   = MathArrays.buildArray(threshold.getField(), blockSize);
390 
391             for (int kBlock = 0; kBlock < cBlocks; ++kBlock) {
392                 final int kStart = kBlock * blockSize;
393                 final int kEnd   = FastMath.min(kStart + blockSize, columns);
394                 final int kWidth = kEnd - kStart;
395 
396                 // get the right hand side vector
397                 b.copySubMatrix(0, m - 1, kStart, kEnd - 1, y);
398 
399                 // apply Householder transforms to solve Q.y = b
400                 for (int minor = 0; minor < FastMath.min(m, n); minor++) {
401                     final T[] qrtMinor = qrt[minor];
402                     final T factor     = rDiag[minor].multiply(qrtMinor[minor]).reciprocal();
403 
404                     Arrays.fill(alpha, 0, kWidth, threshold.getField().getZero());
405                     for (int row = minor; row < m; ++row) {
406                         final T   d    = qrtMinor[row];
407                         final T[] yRow = y[row];
408                         for (int k = 0; k < kWidth; ++k) {
409                             alpha[k] = alpha[k].add(d.multiply(yRow[k]));
410                         }
411                     }
412 
413                     for (int k = 0; k < kWidth; ++k) {
414                         alpha[k] = alpha[k].multiply(factor);
415                     }
416 
417                     for (int row = minor; row < m; ++row) {
418                         final T   d    = qrtMinor[row];
419                         final T[] yRow = y[row];
420                         for (int k = 0; k < kWidth; ++k) {
421                             yRow[k] = yRow[k].add(alpha[k].multiply(d));
422                         }
423                     }
424                 }
425 
426                 // solve triangular system R.x = y
427                 for (int j = rDiag.length - 1; j >= 0; --j) {
428                     final int      jBlock = j / blockSize;
429                     final int      jStart = jBlock * blockSize;
430                     final T   factor = rDiag[j].reciprocal();
431                     final T[] yJ     = y[j];
432                     final T[] xBlock = xBlocks[jBlock * cBlocks + kBlock];
433                     int index = (j - jStart) * kWidth;
434                     for (int k = 0; k < kWidth; ++k) {
435                         yJ[k]           =yJ[k].multiply(factor);
436                         xBlock[index++] = yJ[k];
437                     }
438 
439                     final T[] qrtJ = qrt[j];
440                     for (int i = 0; i < j; ++i) {
441                         final T rIJ  = qrtJ[i];
442                         final T[] yI = y[i];
443                         for (int k = 0; k < kWidth; ++k) {
444                             yI[k] = yI[k].subtract(yJ[k].multiply(rIJ));
445                         }
446                     }
447                 }
448             }
449 
450             return new BlockFieldMatrix<T>(n, columns, xBlocks, false);
451         }
452 
453         /**
454          * {@inheritDoc}
455          * @throws MathIllegalArgumentException if the decomposed matrix is singular.
456          */
457         @Override
458         public FieldMatrix<T> getInverse() {
459             return solve(MatrixUtils.createFieldIdentityMatrix(threshold.getField(), qrt[0].length));
460         }
461 
462         /**
463          * Check singularity.
464          *
465          * @param diag Diagonal elements of the R matrix.
466          * @param min Singularity threshold.
467          * @param raise Whether to raise a {@link MathIllegalArgumentException}
468          * if any element of the diagonal fails the check.
469          * @return {@code true} if any element of the diagonal is smaller
470          * or equal to {@code min}.
471          * @throws MathIllegalArgumentException if the matrix is singular and
472          * {@code raise} is {@code true}.
473          */
474         private boolean checkSingular(T[] diag,
475                                              T min,
476                                              boolean raise) {
477             final int len = diag.length;
478             for (int i = 0; i < len; i++) {
479                 final T d = diag[i];
480                 if (FastMath.abs(d.getReal()) <= min.getReal()) {
481                     if (raise) {
482                         throw new MathIllegalArgumentException(LocalizedCoreFormats.SINGULAR_MATRIX);
483                     } else {
484                         return true;
485                     }
486                 }
487             }
488             return false;
489         }
490 
491         /** {@inheritDoc} */
492         @Override
493         public int getRowDimension() {
494             return qrt[0].length;
495         }
496 
497         /** {@inheritDoc} */
498         @Override
499         public int getColumnDimension() {
500             return qrt.length;
501         }
502 
503     }
504 }