CholeskyDecomposition.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      https://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. /*
  18.  * This is not the original file distributed by the Apache Software Foundation
  19.  * It has been modified by the Hipparchus project
  20.  */

  21. package org.hipparchus.linear;

  22. import org.hipparchus.exception.LocalizedCoreFormats;
  23. import org.hipparchus.exception.MathIllegalArgumentException;
  24. import org.hipparchus.util.FastMath;


  25. /**
  26.  * Calculates the Cholesky decomposition of a matrix.
  27.  * <p>The Cholesky decomposition of a real symmetric positive-definite
  28.  * matrix A consists of a lower triangular matrix L with same size such
  29.  * that: A = LL<sup>T</sup>. In a sense, this is the square root of A.</p>
  30.  * <p>This class is based on the class with similar name from the
  31.  * <a href="http://math.nist.gov/javanumerics/jama/">JAMA</a> library, with the
  32.  * following changes:</p>
  33.  * <ul>
  34.  *   <li>a {@link #getLT() getLT} method has been added,</li>
  35.  *   <li>the {@code isspd} method has been removed, since the constructor of
  36.  *   this class throws a {@link MathIllegalArgumentException} when a
  37.  *   matrix cannot be decomposed,</li>
  38.  *   <li>a {@link #getDeterminant() getDeterminant} method has been added,</li>
  39.  *   <li>the {@code solve} method has been replaced by a {@link #getSolver()
  40.  *   getSolver} method and the equivalent method provided by the returned
  41.  *   {@link DecompositionSolver}.</li>
  42.  * </ul>
  43.  *
  44.  * @see <a href="http://mathworld.wolfram.com/CholeskyDecomposition.html">MathWorld</a>
  45.  * @see <a href="http://en.wikipedia.org/wiki/Cholesky_decomposition">Wikipedia</a>
  46.  */
  47. public class CholeskyDecomposition {
  48.     /**
  49.      * Default threshold above which off-diagonal elements are considered too different
  50.      * and matrix not symmetric.
  51.      */
  52.     public static final double DEFAULT_RELATIVE_SYMMETRY_THRESHOLD = 1.0e-15;
  53.     /**
  54.      * Default threshold below which diagonal elements are considered null
  55.      * and matrix not positive definite.
  56.      */
  57.     public static final double DEFAULT_ABSOLUTE_POSITIVITY_THRESHOLD = 1.0e-10;
  58.     /** Row-oriented storage for L<sup>T</sup> matrix data. */
  59.     private final double[][] lTData;
  60.     /** Cached value of L. */
  61.     private RealMatrix cachedL;
  62.     /** Cached value of LT. */
  63.     private RealMatrix cachedLT;

  64.     /**
  65.      * Calculates the Cholesky decomposition of the given matrix.
  66.      * <p>
  67.      * Calling this constructor is equivalent to call {@link
  68.      * #CholeskyDecomposition(RealMatrix, double, double)} with the
  69.      * thresholds set to the default values {@link
  70.      * #DEFAULT_RELATIVE_SYMMETRY_THRESHOLD} and {@link
  71.      * #DEFAULT_ABSOLUTE_POSITIVITY_THRESHOLD}
  72.      * </p>
  73.      * @param matrix the matrix to decompose
  74.      * @throws MathIllegalArgumentException if the matrix is not square.
  75.      * @throws MathIllegalArgumentException if the matrix is not symmetric.
  76.      * @throws MathIllegalArgumentException if the matrix is not
  77.      * strictly positive definite.
  78.      * @see #CholeskyDecomposition(RealMatrix, double, double)
  79.      * @see #DEFAULT_RELATIVE_SYMMETRY_THRESHOLD
  80.      * @see #DEFAULT_ABSOLUTE_POSITIVITY_THRESHOLD
  81.      */
  82.     public CholeskyDecomposition(final RealMatrix matrix) {
  83.         this(matrix, DEFAULT_RELATIVE_SYMMETRY_THRESHOLD,
  84.              DEFAULT_ABSOLUTE_POSITIVITY_THRESHOLD);
  85.     }

  86.     /**
  87.      * Calculates the Cholesky decomposition of the given matrix.
  88.      * @param matrix the matrix to decompose
  89.      * @param relativeSymmetryThreshold threshold above which off-diagonal
  90.      * elements are considered too different and matrix not symmetric
  91.      * @param absolutePositivityThreshold threshold below which diagonal
  92.      * elements are considered null and matrix not positive definite
  93.      * @throws MathIllegalArgumentException if the matrix is not square.
  94.      * @throws MathIllegalArgumentException if the matrix is not symmetric.
  95.      * @throws MathIllegalArgumentException if the matrix is not
  96.      * strictly positive definite.
  97.      * @see #CholeskyDecomposition(RealMatrix)
  98.      * @see #DEFAULT_RELATIVE_SYMMETRY_THRESHOLD
  99.      * @see #DEFAULT_ABSOLUTE_POSITIVITY_THRESHOLD
  100.      */
  101.     public CholeskyDecomposition(final RealMatrix matrix,
  102.                                  final double relativeSymmetryThreshold,
  103.                                  final double absolutePositivityThreshold) {
  104.         if (!matrix.isSquare()) {
  105.             throw new MathIllegalArgumentException(LocalizedCoreFormats.NON_SQUARE_MATRIX,
  106.                                                    matrix.getRowDimension(), matrix.getColumnDimension());
  107.         }

  108.         final int order = matrix.getRowDimension();
  109.         lTData   = matrix.getData();
  110.         cachedL  = null;
  111.         cachedLT = null;

  112.         // check the matrix before transformation
  113.         for (int i = 0; i < order; ++i) {
  114.             final double[] lI = lTData[i];

  115.             // check off-diagonal elements (and reset them to 0)
  116.             for (int j = i + 1; j < order; ++j) {
  117.                 final double[] lJ = lTData[j];
  118.                 final double lIJ = lI[j];
  119.                 final double lJI = lJ[i];
  120.                 final double maxDelta =
  121.                     relativeSymmetryThreshold * FastMath.max(FastMath.abs(lIJ), FastMath.abs(lJI));
  122.                 if (FastMath.abs(lIJ - lJI) > maxDelta) {
  123.                     throw new MathIllegalArgumentException(LocalizedCoreFormats.NON_SYMMETRIC_MATRIX,
  124.                                                            i, j, relativeSymmetryThreshold);
  125.                 }
  126.                 lJ[i] = 0;
  127.            }
  128.         }

  129.         // transform the matrix
  130.         for (int i = 0; i < order; ++i) {

  131.             final double[] ltI = lTData[i];

  132.             // check diagonal element
  133.             if (ltI[i] <= absolutePositivityThreshold) {
  134.                 throw new MathIllegalArgumentException(LocalizedCoreFormats.NOT_POSITIVE_DEFINITE_MATRIX);
  135.             }

  136.             ltI[i] = FastMath.sqrt(ltI[i]);
  137.             final double inverse = 1.0 / ltI[i];

  138.             for (int q = order - 1; q > i; --q) {
  139.                 ltI[q] *= inverse;
  140.                 final double[] ltQ = lTData[q];
  141.                 for (int p = q; p < order; ++p) {
  142.                     ltQ[p] -= ltI[q] * ltI[p];
  143.                 }
  144.             }
  145.         }
  146.     }

  147.     /**
  148.      * Returns the matrix L of the decomposition.
  149.      * <p>L is an lower-triangular matrix</p>
  150.      * @return the L matrix
  151.      */
  152.     public RealMatrix getL() {
  153.         if (cachedL == null) {
  154.             cachedL = getLT().transpose();
  155.         }
  156.         return cachedL;
  157.     }

  158.     /**
  159.      * Returns the transpose of the matrix L of the decomposition.
  160.      * <p>L<sup>T</sup> is an upper-triangular matrix</p>
  161.      * @return the transpose of the matrix L of the decomposition
  162.      */
  163.     public RealMatrix getLT() {

  164.         if (cachedLT == null) {
  165.             cachedLT = MatrixUtils.createRealMatrix(lTData);
  166.         }

  167.         // return the cached matrix
  168.         return cachedLT;
  169.     }

  170.     /**
  171.      * Return the determinant of the matrix
  172.      * @return determinant of the matrix
  173.      */
  174.     public double getDeterminant() {
  175.         double determinant = 1.0;
  176.         for (int i = 0; i < lTData.length; ++i) {
  177.             double lTii = lTData[i][i];
  178.             determinant *= lTii * lTii;
  179.         }
  180.         return determinant;
  181.     }

  182.     /**
  183.      * Get a solver for finding the A &times; X = B solution in least square sense.
  184.      * @return a solver
  185.      */
  186.     public DecompositionSolver getSolver() {
  187.         return new Solver();
  188.     }

  189.     /** Specialized solver. */
  190.     private class Solver implements DecompositionSolver {

  191.         /** {@inheritDoc} */
  192.         @Override
  193.         public boolean isNonSingular() {
  194.             // if we get this far, the matrix was positive definite, hence non-singular
  195.             return true;
  196.         }

  197.         /** {@inheritDoc} */
  198.         @Override
  199.         public RealVector solve(final RealVector b) {
  200.             final int m = lTData.length;
  201.             if (b.getDimension() != m) {
  202.                 throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
  203.                                                        b.getDimension(), m);
  204.             }

  205.             final double[] x = b.toArray();

  206.             // Solve LY = b
  207.             for (int j = 0; j < m; j++) {
  208.                 final double[] lJ = lTData[j];
  209.                 x[j] /= lJ[j];
  210.                 final double xJ = x[j];
  211.                 for (int i = j + 1; i < m; i++) {
  212.                     x[i] -= xJ * lJ[i];
  213.                 }
  214.             }

  215.             // Solve LTX = Y
  216.             for (int j = m - 1; j >= 0; j--) {
  217.                 x[j] /= lTData[j][j];
  218.                 final double xJ = x[j];
  219.                 for (int i = 0; i < j; i++) {
  220.                     x[i] -= xJ * lTData[i][j];
  221.                 }
  222.             }

  223.             return new ArrayRealVector(x, false);
  224.         }

  225.         /** {@inheritDoc} */
  226.         @Override
  227.         public RealMatrix solve(RealMatrix b) {
  228.             final int m = lTData.length;
  229.             if (b.getRowDimension() != m) {
  230.                 throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
  231.                                                        b.getRowDimension(), m);
  232.             }

  233.             final int nColB = b.getColumnDimension();
  234.             final double[][] x = b.getData();

  235.             // Solve LY = b
  236.             for (int j = 0; j < m; j++) {
  237.                 final double[] lJ = lTData[j];
  238.                 final double lJJ = lJ[j];
  239.                 final double[] xJ = x[j];
  240.                 for (int k = 0; k < nColB; ++k) {
  241.                     xJ[k] /= lJJ;
  242.                 }
  243.                 for (int i = j + 1; i < m; i++) {
  244.                     final double[] xI = x[i];
  245.                     final double lJI = lJ[i];
  246.                     for (int k = 0; k < nColB; ++k) {
  247.                         xI[k] -= xJ[k] * lJI;
  248.                     }
  249.                 }
  250.             }

  251.             // Solve LTX = Y
  252.             for (int j = m - 1; j >= 0; j--) {
  253.                 final double lJJ = lTData[j][j];
  254.                 final double[] xJ = x[j];
  255.                 for (int k = 0; k < nColB; ++k) {
  256.                     xJ[k] /= lJJ;
  257.                 }
  258.                 for (int i = 0; i < j; i++) {
  259.                     final double[] xI = x[i];
  260.                     final double lIJ = lTData[i][j];
  261.                     for (int k = 0; k < nColB; ++k) {
  262.                         xI[k] -= xJ[k] * lIJ;
  263.                     }
  264.                 }
  265.             }

  266.             return new Array2DRowRealMatrix(x);
  267.         }

  268.         /**
  269.          * Get the inverse of the decomposed matrix.
  270.          *
  271.          * @return the inverse matrix.
  272.          */
  273.         @Override
  274.         public RealMatrix getInverse() {
  275.             return solve(MatrixUtils.createRealIdentityMatrix(lTData.length));
  276.         }

  277.         /** {@inheritDoc} */
  278.         @Override
  279.         public int getRowDimension() {
  280.             return lTData.length;
  281.         }

  282.         /** {@inheritDoc} */
  283.         @Override
  284.         public int getColumnDimension() {
  285.             return lTData[0].length;
  286.         }

  287.     }

  288. }