SchurTransformer.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      https://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. /*
  18.  * This is not the original file distributed by the Apache Software Foundation
  19.  * It has been modified by the Hipparchus project
  20.  */

  21. package org.hipparchus.linear;

  22. import org.hipparchus.exception.LocalizedCoreFormats;
  23. import org.hipparchus.exception.MathIllegalArgumentException;
  24. import org.hipparchus.exception.MathIllegalStateException;
  25. import org.hipparchus.util.FastMath;
  26. import org.hipparchus.util.Precision;

  27. /**
  28.  * Class transforming a general real matrix to Schur form.
  29.  * <p>A m &times; m matrix A can be written as the product of three matrices: A = P
  30.  * &times; T &times; P<sup>T</sup> with P an orthogonal matrix and T an quasi-triangular
  31.  * matrix. Both P and T are m &times; m matrices.</p>
  32.  * <p>Transformation to Schur form is often not a goal by itself, but it is an
  33.  * intermediate step in more general decomposition algorithms like
  34.  * {@link EigenDecompositionSymmetric eigen decomposition}. This class is therefore
  35.  * intended for expert use. As a consequence of this explicitly limited scope,
  36.  * many methods directly returns references to internal arrays, not copies.</p>
  37.  * <p>This class is based on the method hqr2 in class EigenvalueDecomposition
  38.  * from the <a href="http://math.nist.gov/javanumerics/jama/">JAMA</a> library.</p>
  39.  *
  40.  * @see <a href="http://mathworld.wolfram.com/SchurDecomposition.html">Schur Decomposition - MathWorld</a>
  41.  * @see <a href="http://en.wikipedia.org/wiki/Schur_decomposition">Schur Decomposition - Wikipedia</a>
  42.  * @see <a href="http://en.wikipedia.org/wiki/Householder_transformation">Householder Transformations</a>
  43.  */
  44. public class SchurTransformer {
  45.     /** Maximum allowed iterations for convergence of the transformation. */
  46.     private static final int MAX_ITERATIONS = 100;

  47.     /** P matrix. */
  48.     private final double[][] matrixP;
  49.     /** T matrix. */
  50.     private final double[][] matrixT;
  51.     /** Cached value of P. */
  52.     private RealMatrix cachedP;
  53.     /** Cached value of T. */
  54.     private RealMatrix cachedT;
  55.     /** Cached value of PT. */
  56.     private RealMatrix cachedPt;

  57.     /** Epsilon criteria. */
  58.     private final double epsilon;

  59.     /**
  60.      * Build the transformation to Schur form of a general real matrix.
  61.      *
  62.      * @param matrix matrix to transform
  63.      * @throws MathIllegalArgumentException if the matrix is not square
  64.      */
  65.     public SchurTransformer(final RealMatrix matrix) {
  66.         /** Epsilon criteria taken from JAMA code (originally was 2^-52). */
  67.         this(matrix, Precision.EPSILON);
  68.     }

  69.     /**
  70.      * Build the transformation to Schur form of a general real matrix.
  71.      *
  72.      * @param matrix matrix to transform
  73.      * @param epsilon convergence criteria
  74.      * @throws MathIllegalArgumentException if the matrix is not square
  75.      * @since 3.0
  76.      */
  77.     public SchurTransformer(final RealMatrix matrix, final double epsilon) {
  78.         if (!matrix.isSquare()) {
  79.             throw new MathIllegalArgumentException(LocalizedCoreFormats.NON_SQUARE_MATRIX,
  80.                                                    matrix.getRowDimension(), matrix.getColumnDimension());
  81.         }
  82.         this.epsilon = epsilon;

  83.         HessenbergTransformer transformer = new HessenbergTransformer(matrix);
  84.         matrixT = transformer.getH().getData();
  85.         matrixP = transformer.getP().getData();
  86.         cachedT = null;
  87.         cachedP = null;
  88.         cachedPt = null;

  89.         // transform matrix
  90.         transform();
  91.     }

  92.     /**
  93.      * Returns the matrix P of the transform.
  94.      * <p>P is an orthogonal matrix, i.e. its inverse is also its transpose.</p>
  95.      *
  96.      * @return the P matrix
  97.      */
  98.     public RealMatrix getP() {
  99.         if (cachedP == null) {
  100.             cachedP = MatrixUtils.createRealMatrix(matrixP);
  101.         }
  102.         return cachedP;
  103.     }

  104.     /**
  105.      * Returns the transpose of the matrix P of the transform.
  106.      * <p>P is an orthogonal matrix, i.e. its inverse is also its transpose.</p>
  107.      *
  108.      * @return the transpose of the P matrix
  109.      */
  110.     public RealMatrix getPT() {
  111.         if (cachedPt == null) {
  112.             cachedPt = getP().transpose();
  113.         }

  114.         // return the cached matrix
  115.         return cachedPt;
  116.     }

  117.     /**
  118.      * Returns the quasi-triangular Schur matrix T of the transform.
  119.      *
  120.      * @return the T matrix
  121.      */
  122.     public RealMatrix getT() {
  123.         if (cachedT == null) {
  124.             cachedT = MatrixUtils.createRealMatrix(matrixT);
  125.         }

  126.         // return the cached matrix
  127.         return cachedT;
  128.     }

  129.     /**
  130.      * Transform original matrix to Schur form.
  131.      * @throws MathIllegalStateException if the transformation does not converge
  132.      */
  133.     private void transform() {
  134.         final int n = matrixT.length;

  135.         // compute matrix norm
  136.         final double norm = getNorm();

  137.         // shift information
  138.         final ShiftInfo shift = new ShiftInfo();

  139.         // Outer loop over eigenvalue index
  140.         int iteration = 0;
  141.         int iu = n - 1;
  142.         while (iu >= 0) {

  143.             // Look for single small sub-diagonal element
  144.             final int il = findSmallSubDiagonalElement(iu, norm);

  145.             // Check for convergence
  146.             if (il == iu) {
  147.                 // One root found
  148.                 matrixT[iu][iu] += shift.exShift;
  149.                 iu--;
  150.                 iteration = 0;
  151.             } else if (il == iu - 1) {
  152.                 // Two roots found
  153.                 double p = (matrixT[iu - 1][iu - 1] - matrixT[iu][iu]) / 2.0;
  154.                 double q = p * p + matrixT[iu][iu - 1] * matrixT[iu - 1][iu];
  155.                 matrixT[iu][iu] += shift.exShift;
  156.                 matrixT[iu - 1][iu - 1] += shift.exShift;

  157.                 if (q >= 0) {
  158.                     double z = FastMath.sqrt(FastMath.abs(q));
  159.                     if (p >= 0) {
  160.                         z = p + z;
  161.                     } else {
  162.                         z = p - z;
  163.                     }
  164.                     final double x = matrixT[iu][iu - 1];
  165.                     final double s = FastMath.abs(x) + FastMath.abs(z);
  166.                     p = x / s;
  167.                     q = z / s;
  168.                     final double r = FastMath.sqrt(p * p + q * q);
  169.                     p /= r;
  170.                     q /= r;

  171.                     // Row modification
  172.                     for (int j = iu - 1; j < n; j++) {
  173.                         z = matrixT[iu - 1][j];
  174.                         matrixT[iu - 1][j] = q * z + p * matrixT[iu][j];
  175.                         matrixT[iu][j] = q * matrixT[iu][j] - p * z;
  176.                     }

  177.                     // Column modification
  178.                     for (int i = 0; i <= iu; i++) {
  179.                         z = matrixT[i][iu - 1];
  180.                         matrixT[i][iu - 1] = q * z + p * matrixT[i][iu];
  181.                         matrixT[i][iu] = q * matrixT[i][iu] - p * z;
  182.                     }

  183.                     // Accumulate transformations
  184.                     for (int i = 0; i <= n - 1; i++) {
  185.                         z = matrixP[i][iu - 1];
  186.                         matrixP[i][iu - 1] = q * z + p * matrixP[i][iu];
  187.                         matrixP[i][iu] = q * matrixP[i][iu] - p * z;
  188.                     }
  189.                 }
  190.                 iu -= 2;
  191.                 iteration = 0;
  192.             } else {
  193.                 // No convergence yet
  194.                 computeShift(il, iu, iteration, shift);

  195.                 // stop transformation after too many iterations
  196.                 ++iteration;
  197.                 if (iteration > MAX_ITERATIONS) {
  198.                     throw new MathIllegalStateException(LocalizedCoreFormats.CONVERGENCE_FAILED,
  199.                                                         MAX_ITERATIONS);
  200.                 }

  201.                 // the initial houseHolder vector for the QR step
  202.                 final double[] hVec = new double[3];

  203.                 final int im = initQRStep(il, iu, shift, hVec);
  204.                 performDoubleQRStep(il, im, iu, shift, hVec, norm);
  205.             }
  206.         }
  207.     }

  208.     /**
  209.      * Computes the L1 norm of the (quasi-)triangular matrix T.
  210.      *
  211.      * @return the L1 norm of matrix T
  212.      */
  213.     private double getNorm() {
  214.         double norm = 0.0;
  215.         for (int i = 0; i < matrixT.length; i++) {
  216.             // as matrix T is (quasi-)triangular, also take the sub-diagonal element into account
  217.             for (int j = FastMath.max(i - 1, 0); j < matrixT.length; j++) {
  218.                 norm += FastMath.abs(matrixT[i][j]);
  219.             }
  220.         }
  221.         return norm;
  222.     }

  223.     /**
  224.      * Find the first small sub-diagonal element and returns its index.
  225.      *
  226.      * @param startIdx the starting index for the search
  227.      * @param norm the L1 norm of the matrix
  228.      * @return the index of the first small sub-diagonal element
  229.      */
  230.     private int findSmallSubDiagonalElement(final int startIdx, final double norm) {
  231.         int l = startIdx;
  232.         while (l > 0) {
  233.             double s = FastMath.abs(matrixT[l - 1][l - 1]) + FastMath.abs(matrixT[l][l]);
  234.             if (s == 0.0) {
  235.                 s = norm;
  236.             }
  237.             if (FastMath.abs(matrixT[l][l - 1]) < epsilon * s) {
  238.                 break;
  239.             }
  240.             l--;
  241.         }
  242.         return l;
  243.     }

  244.     /**
  245.      * Compute the shift for the current iteration.
  246.      *
  247.      * @param l the index of the small sub-diagonal element
  248.      * @param idx the current eigenvalue index
  249.      * @param iteration the current iteration
  250.      * @param shift holder for shift information
  251.      */
  252.     private void computeShift(final int l, final int idx, final int iteration, final ShiftInfo shift) {
  253.         // Form shift
  254.         shift.x = matrixT[idx][idx];
  255.         shift.y = shift.w = 0.0;
  256.         if (l < idx) {
  257.             shift.y = matrixT[idx - 1][idx - 1];
  258.             shift.w = matrixT[idx][idx - 1] * matrixT[idx - 1][idx];
  259.         }

  260.         // Wilkinson's original ad hoc shift
  261.         if (iteration == 10) {
  262.             shift.exShift += shift.x;
  263.             for (int i = 0; i <= idx; i++) {
  264.                 matrixT[i][i] -= shift.x;
  265.             }
  266.             final double s = FastMath.abs(matrixT[idx][idx - 1]) + FastMath.abs(matrixT[idx - 1][idx - 2]);
  267.             shift.x = 0.75 * s;
  268.             shift.y = 0.75 * s;
  269.             shift.w = -0.4375 * s * s;
  270.         }

  271.         // MATLAB's new ad hoc shift
  272.         if (iteration == 30) {
  273.             double s = (shift.y - shift.x) / 2.0;
  274.             s = s * s + shift.w;
  275.             if (s > 0.0) {
  276.                 s = FastMath.sqrt(s);
  277.                 if (shift.y < shift.x) {
  278.                     s = -s;
  279.                 }
  280.                 s = shift.x - shift.w / ((shift.y - shift.x) / 2.0 + s);
  281.                 for (int i = 0; i <= idx; i++) {
  282.                     matrixT[i][i] -= s;
  283.                 }
  284.                 shift.exShift += s;
  285.                 shift.x = shift.y = shift.w = 0.964;
  286.             }
  287.         }
  288.     }

  289.     /**
  290.      * Initialize the householder vectors for the QR step.
  291.      *
  292.      * @param il the index of the small sub-diagonal element
  293.      * @param iu the current eigenvalue index
  294.      * @param shift shift information holder
  295.      * @param hVec the initial houseHolder vector
  296.      * @return the start index for the QR step
  297.      */
  298.     private int initQRStep(int il, final int iu, final ShiftInfo shift, double[] hVec) {
  299.         // Look for two consecutive small sub-diagonal elements
  300.         int im = iu - 2;
  301.         while (im >= il) {
  302.             final double z = matrixT[im][im];
  303.             final double r = shift.x - z;
  304.             double s = shift.y - z;
  305.             hVec[0] = (r * s - shift.w) / matrixT[im + 1][im] + matrixT[im][im + 1];
  306.             hVec[1] = matrixT[im + 1][im + 1] - z - r - s;
  307.             hVec[2] = matrixT[im + 2][im + 1];

  308.             if (im == il) {
  309.                 break;
  310.             }

  311.             final double lhs = FastMath.abs(matrixT[im][im - 1]) * (FastMath.abs(hVec[1]) + FastMath.abs(hVec[2]));
  312.             final double rhs = FastMath.abs(hVec[0]) * (FastMath.abs(matrixT[im - 1][im - 1]) +
  313.                                                         FastMath.abs(z) +
  314.                                                         FastMath.abs(matrixT[im + 1][im + 1]));

  315.             if (lhs < epsilon * rhs) {
  316.                 break;
  317.             }
  318.             im--;
  319.         }

  320.         return im;
  321.     }

  322.     /**
  323.      * Perform a double QR step involving rows l:idx and columns m:n
  324.      *
  325.      * @param il the index of the small sub-diagonal element
  326.      * @param im the start index for the QR step
  327.      * @param iu the current eigenvalue index
  328.      * @param shift shift information holder
  329.      * @param hVec the initial houseHolder vector
  330.      * @param norm matrix norm
  331.      */
  332.     private void performDoubleQRStep(final int il, final int im, final int iu,
  333.                                      final ShiftInfo shift, final double[] hVec,
  334.                                      final double norm) {

  335.         final int n = matrixT.length;
  336.         double p = hVec[0];
  337.         double q = hVec[1];
  338.         double r = hVec[2];

  339.         for (int k = im; k <= iu - 1; k++) {
  340.             boolean notlast = k != (iu - 1);
  341.             if (k != im) {
  342.                 p = matrixT[k][k - 1];
  343.                 q = matrixT[k + 1][k - 1];
  344.                 r = notlast ? matrixT[k + 2][k - 1] : 0.0;
  345.                 shift.x = FastMath.abs(p) + FastMath.abs(q) + FastMath.abs(r);
  346.                 if (Precision.equals(shift.x, 0.0, epsilon * norm)) {
  347.                     continue;
  348.                 }
  349.                 p /= shift.x;
  350.                 q /= shift.x;
  351.                 r /= shift.x;
  352.             }
  353.             double s = FastMath.sqrt(p * p + q * q + r * r);
  354.             if (p < 0.0) {
  355.                 s = -s;
  356.             }
  357.             if (s != 0.0) {
  358.                 if (k != im) {
  359.                     matrixT[k][k - 1] = -s * shift.x;
  360.                 } else if (il != im) {
  361.                     matrixT[k][k - 1] = -matrixT[k][k - 1];
  362.                 }
  363.                 p += s;
  364.                 shift.x = p / s;
  365.                 shift.y = q / s;
  366.                 double z = r / s;
  367.                 q /= p;
  368.                 r /= p;

  369.                 // Row modification
  370.                 for (int j = k; j < n; j++) {
  371.                     p = matrixT[k][j] + q * matrixT[k + 1][j];
  372.                     if (notlast) {
  373.                         p += r * matrixT[k + 2][j];
  374.                         matrixT[k + 2][j] -= p * z;
  375.                     }
  376.                     matrixT[k][j] -= p * shift.x;
  377.                     matrixT[k + 1][j] -= p * shift.y;
  378.                 }

  379.                 // Column modification
  380.                 for (int i = 0; i <= FastMath.min(iu, k + 3); i++) {
  381.                     p = shift.x * matrixT[i][k] + shift.y * matrixT[i][k + 1];
  382.                     if (notlast) {
  383.                         p += z * matrixT[i][k + 2];
  384.                         matrixT[i][k + 2] -= p * r;
  385.                     }
  386.                     matrixT[i][k] -= p;
  387.                     matrixT[i][k + 1] -= p * q;
  388.                 }

  389.                 // Accumulate transformations
  390.                 final int high = matrixT.length - 1;
  391.                 for (int i = 0; i <= high; i++) {
  392.                     p = shift.x * matrixP[i][k] + shift.y * matrixP[i][k + 1];
  393.                     if (notlast) {
  394.                         p += z * matrixP[i][k + 2];
  395.                         matrixP[i][k + 2] -= p * r;
  396.                     }
  397.                     matrixP[i][k] -= p;
  398.                     matrixP[i][k + 1] -= p * q;
  399.                 }
  400.             }  // (s != 0)
  401.         }  // k loop

  402.         // clean up pollution due to round-off errors
  403.         for (int i = im + 2; i <= iu; i++) {
  404.             matrixT[i][i-2] = 0.0;
  405.             if (i > im + 2) {
  406.                 matrixT[i][i-3] = 0.0;
  407.             }
  408.         }
  409.     }

  410.     /**
  411.      * Internal data structure holding the current shift information.
  412.      * Contains variable names as present in the original JAMA code.
  413.      */
  414.     private static class ShiftInfo {
  415.         // CHECKSTYLE: stop all

  416.         /** x shift info */
  417.         double x;
  418.         /** y shift info */
  419.         double y;
  420.         /** w shift info */
  421.         double w;
  422.         /** Indicates an exceptional shift. */
  423.         double exShift;

  424.         // CHECKSTYLE: resume all
  425.     }
  426. }