FieldQRDecomposition.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      https://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. /*
  18.  * This is not the original file distributed by the Apache Software Foundation
  19.  * It has been modified by the Hipparchus project
  20.  */

  21. package org.hipparchus.linear;

  22. import java.util.Arrays;
  23. import java.util.function.Predicate;

  24. import org.hipparchus.CalculusFieldElement;
  25. import org.hipparchus.FieldElement;
  26. import org.hipparchus.exception.LocalizedCoreFormats;
  27. import org.hipparchus.exception.MathIllegalArgumentException;
  28. import org.hipparchus.util.FastMath;
  29. import org.hipparchus.util.MathArrays;


  30. /**
  31.  * Calculates the QR-decomposition of a field matrix.
  32.  * <p>The QR-decomposition of a matrix A consists of two matrices Q and R
  33.  * that satisfy: A = QR, Q is orthogonal (Q<sup>T</sup>Q = I), and R is
  34.  * upper triangular. If A is m&times;n, Q is m&times;m and R m&times;n.</p>
  35.  * <p>This class compute the decomposition using Householder reflectors.</p>
  36.  * <p>For efficiency purposes, the decomposition in packed form is transposed.
  37.  * This allows inner loop to iterate inside rows, which is much more cache-efficient
  38.  * in Java.</p>
  39.  * <p>This class is based on the class {@link QRDecomposition}.</p>
  40.  *
  41.  * @param <T> type of the underlying field elements
  42.  * @see <a href="http://mathworld.wolfram.com/QRDecomposition.html">MathWorld</a>
  43.  * @see <a href="http://en.wikipedia.org/wiki/QR_decomposition">Wikipedia</a>
  44.  *
  45.  */
  46. public class FieldQRDecomposition<T extends CalculusFieldElement<T>> {
  47.     /**
  48.      * A packed TRANSPOSED representation of the QR decomposition.
  49.      * <p>The elements BELOW the diagonal are the elements of the UPPER triangular
  50.      * matrix R, and the rows ABOVE the diagonal are the Householder reflector vectors
  51.      * from which an explicit form of Q can be recomputed if desired.</p>
  52.      */
  53.     private T[][] qrt;
  54.     /** The diagonal elements of R. */
  55.     private T[] rDiag;
  56.     /** Cached value of Q. */
  57.     private FieldMatrix<T> cachedQ;
  58.     /** Cached value of QT. */
  59.     private FieldMatrix<T> cachedQT;
  60.     /** Cached value of R. */
  61.     private FieldMatrix<T> cachedR;
  62.     /** Cached value of H. */
  63.     private FieldMatrix<T> cachedH;
  64.     /** Singularity threshold. */
  65.     private final T threshold;
  66.     /** checker for zero. */
  67.     private final Predicate<T> zeroChecker;

  68.     /**
  69.      * Calculates the QR-decomposition of the given matrix.
  70.      * The singularity threshold defaults to zero.
  71.      *
  72.      * @param matrix The matrix to decompose.
  73.      *
  74.      * @see #FieldQRDecomposition(FieldMatrix, CalculusFieldElement)
  75.      */
  76.     public FieldQRDecomposition(FieldMatrix<T> matrix) {
  77.         this(matrix, matrix.getField().getZero());
  78.     }

  79.     /**
  80.      * Calculates the QR-decomposition of the given matrix.
  81.      *
  82.      * @param matrix The matrix to decompose.
  83.      * @param threshold Singularity threshold.
  84.      */
  85.     public FieldQRDecomposition(FieldMatrix<T> matrix, T threshold) {
  86.         this(matrix, threshold, FieldElement::isZero);
  87.     }

  88.     /**
  89.      * Calculates the QR-decomposition of the given matrix.
  90.      *
  91.      * @param matrix The matrix to decompose.
  92.      * @param threshold Singularity threshold.
  93.      * @param zeroChecker checker for zero
  94.      */
  95.     public FieldQRDecomposition(FieldMatrix<T> matrix, T threshold, Predicate<T> zeroChecker) {
  96.         this.threshold   = threshold;
  97.         this.zeroChecker = zeroChecker;

  98.         final int m = matrix.getRowDimension();
  99.         final int n = matrix.getColumnDimension();
  100.         qrt = matrix.transpose().getData();
  101.         rDiag = MathArrays.buildArray(threshold.getField(),FastMath.min(m, n));
  102.         cachedQ  = null;
  103.         cachedQT = null;
  104.         cachedR  = null;
  105.         cachedH  = null;

  106.         decompose(qrt);

  107.     }

  108.     /** Decompose matrix.
  109.      * @param matrix transposed matrix
  110.      */
  111.     protected void decompose(T[][] matrix) {
  112.         for (int minor = 0; minor < FastMath.min(matrix.length, matrix[0].length); minor++) {
  113.             performHouseholderReflection(minor, matrix);
  114.         }
  115.     }

  116.     /** Perform Householder reflection for a minor A(minor, minor) of A.
  117.      * @param minor minor index
  118.      * @param matrix transposed matrix
  119.      */
  120.     protected void performHouseholderReflection(int minor, T[][] matrix) {

  121.         final T[] qrtMinor = matrix[minor];
  122.         final T zero = threshold.getField().getZero();
  123.         /*
  124.          * Let x be the first column of the minor, and a^2 = |x|^2.
  125.          * x will be in the positions qr[minor][minor] through qr[m][minor].
  126.          * The first column of the transformed minor will be (a,0,0,..)'
  127.          * The sign of a is chosen to be opposite to the sign of the first
  128.          * component of x. Let's find a:
  129.          */
  130.         T xNormSqr = zero;
  131.         for (int row = minor; row < qrtMinor.length; row++) {
  132.             final T c = qrtMinor[row];
  133.             xNormSqr = xNormSqr.add(c.square());
  134.         }
  135.         final T a = (qrtMinor[minor].getReal() > 0) ? xNormSqr.sqrt().negate() : xNormSqr.sqrt();
  136.         rDiag[minor] = a;

  137.         if (!zeroChecker.test(a)) {

  138.             /*
  139.              * Calculate the normalized reflection vector v and transform
  140.              * the first column. We know the norm of v beforehand: v = x-ae
  141.              * so |v|^2 = <x-ae,x-ae> = <x,x>-2a<x,e>+a^2<e,e> =
  142.              * a^2+a^2-2a<x,e> = 2a*(a - <x,e>).
  143.              * Here <x, e> is now qr[minor][minor].
  144.              * v = x-ae is stored in the column at qr:
  145.              */
  146.             qrtMinor[minor] = qrtMinor[minor].subtract(a); // now |v|^2 = -2a*(qr[minor][minor])

  147.             /*
  148.              * Transform the rest of the columns of the minor:
  149.              * They will be transformed by the matrix H = I-2vv'/|v|^2.
  150.              * If x is a column vector of the minor, then
  151.              * Hx = (I-2vv'/|v|^2)x = x-2vv'x/|v|^2 = x - 2<x,v>/|v|^2 v.
  152.              * Therefore the transformation is easily calculated by
  153.              * subtracting the column vector (2<x,v>/|v|^2)v from x.
  154.              *
  155.              * Let 2<x,v>/|v|^2 = alpha. From above we have
  156.              * |v|^2 = -2a*(qr[minor][minor]), so
  157.              * alpha = -<x,v>/(a*qr[minor][minor])
  158.              */
  159.             for (int col = minor+1; col < matrix.length; col++) {
  160.                 final T[] qrtCol = matrix[col];
  161.                 T alpha = zero;
  162.                 for (int row = minor; row < qrtCol.length; row++) {
  163.                     alpha = alpha.subtract(qrtCol[row].multiply(qrtMinor[row]));
  164.                 }
  165.                 alpha = alpha.divide(a.multiply(qrtMinor[minor]));

  166.                 // Subtract the column vector alpha*v from x.
  167.                 for (int row = minor; row < qrtCol.length; row++) {
  168.                     qrtCol[row] = qrtCol[row].subtract(alpha.multiply(qrtMinor[row]));
  169.                 }
  170.             }
  171.         }
  172.     }


  173.     /**
  174.      * Returns the matrix R of the decomposition.
  175.      * <p>R is an upper-triangular matrix</p>
  176.      * @return the R matrix
  177.      */
  178.     public FieldMatrix<T> getR() {

  179.         if (cachedR == null) {

  180.             // R is supposed to be m x n
  181.             final int n = qrt.length;
  182.             final int m = qrt[0].length;
  183.             T[][] ra = MathArrays.buildArray(threshold.getField(), m, n);
  184.             // copy the diagonal from rDiag and the upper triangle of qr
  185.             for (int row = FastMath.min(m, n) - 1; row >= 0; row--) {
  186.                 ra[row][row] = rDiag[row];
  187.                 for (int col = row + 1; col < n; col++) {
  188.                     ra[row][col] = qrt[col][row];
  189.                 }
  190.             }
  191.             cachedR = MatrixUtils.createFieldMatrix(ra);
  192.         }

  193.         // return the cached matrix
  194.         return cachedR;
  195.     }

  196.     /**
  197.      * Returns the matrix Q of the decomposition.
  198.      * <p>Q is an orthogonal matrix</p>
  199.      * @return the Q matrix
  200.      */
  201.     public FieldMatrix<T> getQ() {
  202.         if (cachedQ == null) {
  203.             cachedQ = getQT().transpose();
  204.         }
  205.         return cachedQ;
  206.     }

  207.     /**
  208.      * Returns the transpose of the matrix Q of the decomposition.
  209.      * <p>Q is an orthogonal matrix</p>
  210.      * @return the transpose of the Q matrix, Q<sup>T</sup>
  211.      */
  212.     public FieldMatrix<T> getQT() {
  213.         if (cachedQT == null) {

  214.             // QT is supposed to be m x m
  215.             final int n = qrt.length;
  216.             final int m = qrt[0].length;
  217.             T[][] qta = MathArrays.buildArray(threshold.getField(), m, m);

  218.             /*
  219.              * Q = Q1 Q2 ... Q_m, so Q is formed by first constructing Q_m and then
  220.              * applying the Householder transformations Q_(m-1),Q_(m-2),...,Q1 in
  221.              * succession to the result
  222.              */
  223.             for (int minor = m - 1; minor >= FastMath.min(m, n); minor--) {
  224.                 qta[minor][minor] = threshold.getField().getOne();
  225.             }

  226.             for (int minor = FastMath.min(m, n)-1; minor >= 0; minor--){
  227.                 final T[] qrtMinor = qrt[minor];
  228.                 qta[minor][minor] = threshold.getField().getOne();
  229.                 if (!qrtMinor[minor].isZero()) {
  230.                     for (int col = minor; col < m; col++) {
  231.                         T alpha = threshold.getField().getZero();
  232.                         for (int row = minor; row < m; row++) {
  233.                             alpha = alpha.subtract(qta[col][row].multiply(qrtMinor[row]));
  234.                         }
  235.                         alpha = alpha.divide(rDiag[minor].multiply(qrtMinor[minor]));

  236.                         for (int row = minor; row < m; row++) {
  237.                             qta[col][row] = qta[col][row].add(alpha.negate().multiply(qrtMinor[row]));
  238.                         }
  239.                     }
  240.                 }
  241.             }
  242.             cachedQT = MatrixUtils.createFieldMatrix(qta);
  243.         }

  244.         // return the cached matrix
  245.         return cachedQT;
  246.     }

  247.     /**
  248.      * Returns the Householder reflector vectors.
  249.      * <p>H is a lower trapezoidal matrix whose columns represent
  250.      * each successive Householder reflector vector. This matrix is used
  251.      * to compute Q.</p>
  252.      * @return a matrix containing the Householder reflector vectors
  253.      */
  254.     public FieldMatrix<T> getH() {
  255.         if (cachedH == null) {

  256.             final int n = qrt.length;
  257.             final int m = qrt[0].length;
  258.             T[][] ha = MathArrays.buildArray(threshold.getField(), m, n);
  259.             for (int i = 0; i < m; ++i) {
  260.                 for (int j = 0; j < FastMath.min(i + 1, n); ++j) {
  261.                     ha[i][j] = qrt[j][i].divide(rDiag[j].negate());
  262.                 }
  263.             }
  264.             cachedH = MatrixUtils.createFieldMatrix(ha);
  265.         }

  266.         // return the cached matrix
  267.         return cachedH;
  268.     }

  269.     /**
  270.      * Get a solver for finding the A &times; X = B solution in least square sense.
  271.      * <p>
  272.      * Least Square sense means a solver can be computed for an overdetermined system,
  273.      * (i.e. a system with more equations than unknowns, which corresponds to a tall A
  274.      * matrix with more rows than columns). In any case, if the matrix is singular
  275.      * within the tolerance set at {@link #FieldQRDecomposition(FieldMatrix,
  276.      * CalculusFieldElement) construction}, an error will be triggered when
  277.      * the {@link DecompositionSolver#solve(RealVector) solve} method will be called.
  278.      * </p>
  279.      * @return a solver
  280.      */
  281.     public FieldDecompositionSolver<T> getSolver() {
  282.         return new FieldSolver();
  283.     }

  284.     /**
  285.      * Specialized solver.
  286.      */
  287.     private class FieldSolver implements FieldDecompositionSolver<T>{

  288.         /** {@inheritDoc} */
  289.         @Override
  290.         public boolean isNonSingular() {
  291.             return !checkSingular(rDiag, threshold, false);
  292.         }

  293.         /** {@inheritDoc} */
  294.         @Override
  295.         public FieldVector<T> solve(FieldVector<T> b) {
  296.             final int n = qrt.length;
  297.             final int m = qrt[0].length;
  298.             if (b.getDimension() != m) {
  299.                 throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
  300.                                                        b.getDimension(), m);
  301.             }
  302.             checkSingular(rDiag, threshold, true);

  303.             final T[] x =MathArrays.buildArray(threshold.getField(),n);
  304.             final T[] y = b.toArray();

  305.             // apply Householder transforms to solve Q.y = b
  306.             for (int minor = 0; minor < FastMath.min(m, n); minor++) {

  307.                 final T[] qrtMinor = qrt[minor];
  308.                 T dotProduct = threshold.getField().getZero();
  309.                 for (int row = minor; row < m; row++) {
  310.                     dotProduct = dotProduct.add(y[row].multiply(qrtMinor[row]));
  311.                 }
  312.                 dotProduct =  dotProduct.divide(rDiag[minor].multiply(qrtMinor[minor]));

  313.                 for (int row = minor; row < m; row++) {
  314.                     y[row] = y[row].add(dotProduct.multiply(qrtMinor[row]));
  315.                 }
  316.             }

  317.             // solve triangular system R.x = y
  318.             for (int row = rDiag.length - 1; row >= 0; --row) {
  319.                 y[row] = y[row].divide(rDiag[row]);
  320.                 final T yRow = y[row];
  321.                 final T[] qrtRow = qrt[row];
  322.                 x[row] = yRow;
  323.                 for (int i = 0; i < row; i++) {
  324.                     y[i] = y[i].subtract(yRow.multiply(qrtRow[i]));
  325.                 }
  326.             }

  327.             return new ArrayFieldVector<>(x, false);
  328.         }

  329.         /** {@inheritDoc} */
  330.         @Override
  331.         public FieldMatrix<T> solve(FieldMatrix<T> b) {
  332.             final int n = qrt.length;
  333.             final int m = qrt[0].length;
  334.             if (b.getRowDimension() != m) {
  335.                 throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
  336.                                                        b.getRowDimension(), m);
  337.             }
  338.             checkSingular(rDiag, threshold, true);

  339.             final int columns        = b.getColumnDimension();
  340.             final int blockSize      = BlockFieldMatrix.BLOCK_SIZE;
  341.             final int cBlocks        = (columns + blockSize - 1) / blockSize;
  342.             final T[][] xBlocks = BlockFieldMatrix.createBlocksLayout(threshold.getField(),n, columns);
  343.             final T[][] y       = MathArrays.buildArray(threshold.getField(), b.getRowDimension(), blockSize);
  344.             final T[]   alpha   = MathArrays.buildArray(threshold.getField(), blockSize);

  345.             for (int kBlock = 0; kBlock < cBlocks; ++kBlock) {
  346.                 final int kStart = kBlock * blockSize;
  347.                 final int kEnd   = FastMath.min(kStart + blockSize, columns);
  348.                 final int kWidth = kEnd - kStart;

  349.                 // get the right hand side vector
  350.                 b.copySubMatrix(0, m - 1, kStart, kEnd - 1, y);

  351.                 // apply Householder transforms to solve Q.y = b
  352.                 for (int minor = 0; minor < FastMath.min(m, n); minor++) {
  353.                     final T[] qrtMinor = qrt[minor];
  354.                     final T factor     = rDiag[minor].multiply(qrtMinor[minor]).reciprocal();

  355.                     Arrays.fill(alpha, 0, kWidth, threshold.getField().getZero());
  356.                     for (int row = minor; row < m; ++row) {
  357.                         final T   d    = qrtMinor[row];
  358.                         final T[] yRow = y[row];
  359.                         for (int k = 0; k < kWidth; ++k) {
  360.                             alpha[k] = alpha[k].add(d.multiply(yRow[k]));
  361.                         }
  362.                     }

  363.                     for (int k = 0; k < kWidth; ++k) {
  364.                         alpha[k] = alpha[k].multiply(factor);
  365.                     }

  366.                     for (int row = minor; row < m; ++row) {
  367.                         final T   d    = qrtMinor[row];
  368.                         final T[] yRow = y[row];
  369.                         for (int k = 0; k < kWidth; ++k) {
  370.                             yRow[k] = yRow[k].add(alpha[k].multiply(d));
  371.                         }
  372.                     }
  373.                 }

  374.                 // solve triangular system R.x = y
  375.                 for (int j = rDiag.length - 1; j >= 0; --j) {
  376.                     final int      jBlock = j / blockSize;
  377.                     final int      jStart = jBlock * blockSize;
  378.                     final T   factor = rDiag[j].reciprocal();
  379.                     final T[] yJ     = y[j];
  380.                     final T[] xBlock = xBlocks[jBlock * cBlocks + kBlock];
  381.                     int index = (j - jStart) * kWidth;
  382.                     for (int k = 0; k < kWidth; ++k) {
  383.                         yJ[k]           =yJ[k].multiply(factor);
  384.                         xBlock[index++] = yJ[k];
  385.                     }

  386.                     final T[] qrtJ = qrt[j];
  387.                     for (int i = 0; i < j; ++i) {
  388.                         final T rIJ  = qrtJ[i];
  389.                         final T[] yI = y[i];
  390.                         for (int k = 0; k < kWidth; ++k) {
  391.                             yI[k] = yI[k].subtract(yJ[k].multiply(rIJ));
  392.                         }
  393.                     }
  394.                 }
  395.             }

  396.             return new BlockFieldMatrix<>(n, columns, xBlocks, false);
  397.         }

  398.         /**
  399.          * {@inheritDoc}
  400.          * @throws MathIllegalArgumentException if the decomposed matrix is singular.
  401.          */
  402.         @Override
  403.         public FieldMatrix<T> getInverse() {
  404.             return solve(MatrixUtils.createFieldIdentityMatrix(threshold.getField(), qrt[0].length));
  405.         }

  406.         /**
  407.          * Check singularity.
  408.          *
  409.          * @param diag Diagonal elements of the R matrix.
  410.          * @param min Singularity threshold.
  411.          * @param raise Whether to raise a {@link MathIllegalArgumentException}
  412.          * if any element of the diagonal fails the check.
  413.          * @return {@code true} if any element of the diagonal is smaller
  414.          * or equal to {@code min}.
  415.          * @throws MathIllegalArgumentException if the matrix is singular and
  416.          * {@code raise} is {@code true}.
  417.          */
  418.         private boolean checkSingular(T[] diag,
  419.                                              T min,
  420.                                              boolean raise) {
  421.             for (final T d : diag) {
  422.                 if (FastMath.abs(d.getReal()) <= min.getReal()) {
  423.                     if (raise) {
  424.                         throw new MathIllegalArgumentException(LocalizedCoreFormats.SINGULAR_MATRIX);
  425.                     } else {
  426.                         return true;
  427.                     }
  428.                 }
  429.             }
  430.             return false;
  431.         }

  432.         /** {@inheritDoc} */
  433.         @Override
  434.         public int getRowDimension() {
  435.             return qrt[0].length;
  436.         }

  437.         /** {@inheritDoc} */
  438.         @Override
  439.         public int getColumnDimension() {
  440.             return qrt.length;
  441.         }

  442.     }
  443. }