FieldLUDecomposition.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      https://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. /*
  18.  * This is not the original file distributed by the Apache Software Foundation
  19.  * It has been modified by the Hipparchus project
  20.  */

  21. package org.hipparchus.linear;

  22. import java.util.function.Predicate;

  23. import org.hipparchus.Field;
  24. import org.hipparchus.FieldElement;
  25. import org.hipparchus.exception.LocalizedCoreFormats;
  26. import org.hipparchus.exception.MathIllegalArgumentException;
  27. import org.hipparchus.util.FastMath;
  28. import org.hipparchus.util.MathArrays;

  29. /**
  30.  * Calculates the LUP-decomposition of a square matrix.
  31.  * <p>The LUP-decomposition of a matrix A consists of three matrices
  32.  * L, U and P that satisfy: PA = LU, L is lower triangular, and U is
  33.  * upper triangular and P is a permutation matrix. All matrices are
  34.  * m&times;m.</p>
  35.  * <p>This class is based on the class with similar name from the
  36.  * <a href="http://math.nist.gov/javanumerics/jama/">JAMA</a> library.</p>
  37.  * <ul>
  38.  *   <li>a {@link #getP() getP} method has been added,</li>
  39.  *   <li>the {@code det} method has been renamed as {@link #getDeterminant()
  40.  *   getDeterminant},</li>
  41.  *   <li>the {@code getDoublePivot} method has been removed (but the int based
  42.  *   {@link #getPivot() getPivot} method has been kept),</li>
  43.  *   <li>the {@code solve} and {@code isNonSingular} methods have been replaced
  44.  *   by a {@link #getSolver() getSolver} method and the equivalent methods
  45.  *   provided by the returned {@link DecompositionSolver}.</li>
  46.  * </ul>
  47.  *
  48.  * @param <T> the type of the field elements
  49.  * @see <a href="http://mathworld.wolfram.com/LUDecomposition.html">MathWorld</a>
  50.  * @see <a href="http://en.wikipedia.org/wiki/LU_decomposition">Wikipedia</a>
  51.  */
  52. public class FieldLUDecomposition<T extends FieldElement<T>> {

  53.     /** Field to which the elements belong. */
  54.     private final Field<T> field;

  55.     /** Entries of LU decomposition. */
  56.     private T[][] lu;

  57.     /** Pivot permutation associated with LU decomposition. */
  58.     private int[] pivot;

  59.     /** Parity of the permutation associated with the LU decomposition. */
  60.     private boolean even;

  61.     /** Singularity indicator. */
  62.     private boolean singular;

  63.     /** Cached value of L. */
  64.     private FieldMatrix<T> cachedL;

  65.     /** Cached value of U. */
  66.     private FieldMatrix<T> cachedU;

  67.     /** Cached value of P. */
  68.     private FieldMatrix<T> cachedP;

  69.     /**
  70.      * Calculates the LU-decomposition of the given matrix.
  71.      * <p>
  72.      * By default, <code>numericPermutationChoice</code> is set to <code>true</code>.
  73.      * </p>
  74.      * @param matrix The matrix to decompose.
  75.      * @throws MathIllegalArgumentException if matrix is not square
  76.      * @see #FieldLUDecomposition(FieldMatrix, Predicate)
  77.      * @see #FieldLUDecomposition(FieldMatrix, Predicate, boolean)
  78.      */
  79.     public FieldLUDecomposition(FieldMatrix<T> matrix) {
  80.         this(matrix, FieldElement::isZero);
  81.     }

  82.     /**
  83.      * Calculates the LU-decomposition of the given matrix.
  84.      * <p>
  85.      * By default, <code>numericPermutationChoice</code> is set to <code>true</code>.
  86.      * </p>
  87.      * @param matrix The matrix to decompose.
  88.      * @param zeroChecker checker for zero elements
  89.      * @throws MathIllegalArgumentException if matrix is not square
  90.      * @see #FieldLUDecomposition(FieldMatrix, Predicate, boolean)
  91.      */
  92.     public FieldLUDecomposition(FieldMatrix<T> matrix, final Predicate<T> zeroChecker ) {
  93.         this(matrix, zeroChecker, true);
  94.     }

  95.     /**
  96.      * Calculates the LU-decomposition of the given matrix.
  97.      * @param matrix The matrix to decompose.
  98.      * @param zeroChecker checker for zero elements
  99.      * @param numericPermutationChoice if <code>true</code> choose permutation index with numeric calculations, otherwise choose with <code>zeroChecker</code>
  100.      * @throws MathIllegalArgumentException if matrix is not square
  101.      */
  102.     public FieldLUDecomposition(FieldMatrix<T> matrix, final Predicate<T> zeroChecker, boolean numericPermutationChoice) {
  103.         if (!matrix.isSquare()) {
  104.             throw new MathIllegalArgumentException(LocalizedCoreFormats.NON_SQUARE_MATRIX,
  105.                                                    matrix.getRowDimension(), matrix.getColumnDimension());
  106.         }

  107.         final int m = matrix.getColumnDimension();
  108.         field = matrix.getField();
  109.         lu = matrix.getData();
  110.         pivot = new int[m];
  111.         cachedL = null;
  112.         cachedU = null;
  113.         cachedP = null;

  114.         // Initialize permutation array and parity
  115.         for (int row = 0; row < m; row++) {
  116.             pivot[row] = row;
  117.         }
  118.         even     = true;
  119.         singular = false;

  120.         // Loop over columns
  121.         for (int col = 0; col < m; col++) {

  122.             // upper
  123.             for (int row = 0; row < col; row++) {
  124.                 final T[] luRow = lu[row];
  125.                 T sum = luRow[col];
  126.                 for (int i = 0; i < row; i++) {
  127.                     sum = sum.subtract(luRow[i].multiply(lu[i][col]));
  128.                 }
  129.                 luRow[col] = sum;
  130.             }

  131.             int max = col; // permutation row
  132.             if (numericPermutationChoice) {

  133.                 // lower
  134.                 double largest = Double.NEGATIVE_INFINITY;

  135.                 for (int row = col; row < m; row++) {
  136.                     final T[] luRow = lu[row];
  137.                     T sum = luRow[col];
  138.                     for (int i = 0; i < col; i++) {
  139.                         sum = sum.subtract(luRow[i].multiply(lu[i][col]));
  140.                     }
  141.                     luRow[col] = sum;

  142.                     // maintain best permutation choice
  143.                     double absSum = FastMath.abs(sum.getReal());
  144.                     if (absSum > largest) {
  145.                         largest = absSum;
  146.                         max = row;
  147.                     }
  148.                 }

  149.             } else {

  150.                 // lower
  151.                 int nonZero = col; // permutation row
  152.                 for (int row = col; row < m; row++) {
  153.                     final T[] luRow = lu[row];
  154.                     T sum = luRow[col];
  155.                     for (int i = 0; i < col; i++) {
  156.                         sum = sum.subtract(luRow[i].multiply(lu[i][col]));
  157.                     }
  158.                     luRow[col] = sum;

  159.                     if (zeroChecker.test(lu[nonZero][col])) {
  160.                         // try to select a better permutation choice
  161.                         ++nonZero;
  162.                     }
  163.                 }
  164.                 max = FastMath.min(m - 1, nonZero);

  165.             }

  166.             // Singularity check
  167.             if (zeroChecker.test(lu[max][col])) {
  168.                 singular = true;
  169.                 return;
  170.             }

  171.             // Pivot if necessary
  172.             if (max != col) {
  173.                 final T[] luMax = lu[max];
  174.                 final T[] luCol = lu[col];
  175.                 for (int i = 0; i < m; i++) {
  176.                     final T tmp = luMax[i];
  177.                     luMax[i] = luCol[i];
  178.                     luCol[i] = tmp;
  179.                 }
  180.                 int temp = pivot[max];
  181.                 pivot[max] = pivot[col];
  182.                 pivot[col] = temp;
  183.                 even = !even;
  184.             }

  185.             // Divide the lower elements by the "winning" diagonal elt.
  186.             final T luDiag = lu[col][col];
  187.             for (int row = col + 1; row < m; row++) {
  188.                 lu[row][col] = lu[row][col].divide(luDiag);
  189.             }
  190.         }

  191.     }

  192.     /**
  193.      * Returns the matrix L of the decomposition.
  194.      * <p>L is a lower-triangular matrix</p>
  195.      * @return the L matrix (or null if decomposed matrix is singular)
  196.      */
  197.     public FieldMatrix<T> getL() {
  198.         if ((cachedL == null) && !singular) {
  199.             final int m = pivot.length;
  200.             cachedL = new Array2DRowFieldMatrix<>(field, m, m);
  201.             for (int i = 0; i < m; ++i) {
  202.                 final T[] luI = lu[i];
  203.                 for (int j = 0; j < i; ++j) {
  204.                     cachedL.setEntry(i, j, luI[j]);
  205.                 }
  206.                 cachedL.setEntry(i, i, field.getOne());
  207.             }
  208.         }
  209.         return cachedL;
  210.     }

  211.     /**
  212.      * Returns the matrix U of the decomposition.
  213.      * <p>U is an upper-triangular matrix</p>
  214.      * @return the U matrix (or null if decomposed matrix is singular)
  215.      */
  216.     public FieldMatrix<T> getU() {
  217.         if ((cachedU == null) && !singular) {
  218.             final int m = pivot.length;
  219.             cachedU = new Array2DRowFieldMatrix<>(field, m, m);
  220.             for (int i = 0; i < m; ++i) {
  221.                 final T[] luI = lu[i];
  222.                 for (int j = i; j < m; ++j) {
  223.                     cachedU.setEntry(i, j, luI[j]);
  224.                 }
  225.             }
  226.         }
  227.         return cachedU;
  228.     }

  229.     /**
  230.      * Returns the P rows permutation matrix.
  231.      * <p>P is a sparse matrix with exactly one element set to 1.0 in
  232.      * each row and each column, all other elements being set to 0.0.</p>
  233.      * <p>The positions of the 1 elements are given by the {@link #getPivot()
  234.      * pivot permutation vector}.</p>
  235.      * @return the P rows permutation matrix (or null if decomposed matrix is singular)
  236.      * @see #getPivot()
  237.      */
  238.     public FieldMatrix<T> getP() {
  239.         if ((cachedP == null) && !singular) {
  240.             final int m = pivot.length;
  241.             cachedP = new Array2DRowFieldMatrix<>(field, m, m);
  242.             for (int i = 0; i < m; ++i) {
  243.                 cachedP.setEntry(i, pivot[i], field.getOne());
  244.             }
  245.         }
  246.         return cachedP;
  247.     }

  248.     /**
  249.      * Returns the pivot permutation vector.
  250.      * @return the pivot permutation vector
  251.      * @see #getP()
  252.      */
  253.     public int[] getPivot() {
  254.         return pivot.clone();
  255.     }

  256.     /**
  257.      * Return the determinant of the matrix.
  258.      * @return determinant of the matrix
  259.      */
  260.     public T getDeterminant() {
  261.         if (singular) {
  262.             return field.getZero();
  263.         } else {
  264.             final int m = pivot.length;
  265.             T determinant = even ? field.getOne() : field.getZero().subtract(field.getOne());
  266.             for (int i = 0; i < m; i++) {
  267.                 determinant = determinant.multiply(lu[i][i]);
  268.             }
  269.             return determinant;
  270.         }
  271.     }

  272.     /**
  273.      * Get a solver for finding the A &times; X = B solution in exact linear sense.
  274.      * @return a solver
  275.      */
  276.     public FieldDecompositionSolver<T> getSolver() {
  277.         return new Solver();
  278.     }

  279.     /** Specialized solver.
  280.      */
  281.     private class Solver implements FieldDecompositionSolver<T> {

  282.         /** {@inheritDoc} */
  283.         @Override
  284.         public boolean isNonSingular() {
  285.             return !singular;
  286.         }

  287.         /** {@inheritDoc} */
  288.         @Override
  289.         public FieldVector<T> solve(FieldVector<T> b) {
  290.             if (b instanceof ArrayFieldVector) {
  291.                 return solve((ArrayFieldVector<T>) b);
  292.             } else {

  293.                 final int m = pivot.length;
  294.                 if (b.getDimension() != m) {
  295.                     throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
  296.                                                            b.getDimension(), m);
  297.                 }
  298.                 if (singular) {
  299.                     throw new MathIllegalArgumentException(LocalizedCoreFormats.SINGULAR_MATRIX);
  300.                 }

  301.                 // Apply permutations to b
  302.                 final T[] bp = MathArrays.buildArray(field, m);
  303.                 for (int row = 0; row < m; row++) {
  304.                     bp[row] = b.getEntry(pivot[row]);
  305.                 }

  306.                 // Solve LY = b
  307.                 for (int col = 0; col < m; col++) {
  308.                     final T bpCol = bp[col];
  309.                     for (int i = col + 1; i < m; i++) {
  310.                         bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
  311.                     }
  312.                 }

  313.                 // Solve UX = Y
  314.                 for (int col = m - 1; col >= 0; col--) {
  315.                     bp[col] = bp[col].divide(lu[col][col]);
  316.                     final T bpCol = bp[col];
  317.                     for (int i = 0; i < col; i++) {
  318.                         bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
  319.                     }
  320.                 }

  321.                 return new ArrayFieldVector<>(field, bp, false);

  322.             }
  323.         }

  324.         /** Solve the linear equation A &times; X = B.
  325.          * <p>The A matrix is implicit here. It is </p>
  326.          * @param b right-hand side of the equation A &times; X = B
  327.          * @return a vector X such that A &times; X = B
  328.          * @throws MathIllegalArgumentException if the matrices dimensions do not match.
  329.          * @throws MathIllegalArgumentException if the decomposed matrix is singular.
  330.          */
  331.         public ArrayFieldVector<T> solve(ArrayFieldVector<T> b) {
  332.             final int m = pivot.length;
  333.             final int length = b.getDimension();
  334.             if (length != m) {
  335.                 throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
  336.                                                        length, m);
  337.             }
  338.             if (singular) {
  339.                 throw new MathIllegalArgumentException(LocalizedCoreFormats.SINGULAR_MATRIX);
  340.             }

  341.             // Apply permutations to b
  342.             final T[] bp = MathArrays.buildArray(field, m);
  343.             for (int row = 0; row < m; row++) {
  344.                 bp[row] = b.getEntry(pivot[row]);
  345.             }

  346.             // Solve LY = b
  347.             for (int col = 0; col < m; col++) {
  348.                 final T bpCol = bp[col];
  349.                 for (int i = col + 1; i < m; i++) {
  350.                     bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
  351.                 }
  352.             }

  353.             // Solve UX = Y
  354.             for (int col = m - 1; col >= 0; col--) {
  355.                 bp[col] = bp[col].divide(lu[col][col]);
  356.                 final T bpCol = bp[col];
  357.                 for (int i = 0; i < col; i++) {
  358.                     bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
  359.                 }
  360.             }

  361.             return new ArrayFieldVector<>(bp, false);
  362.         }

  363.         /** {@inheritDoc} */
  364.         @Override
  365.         public FieldMatrix<T> solve(FieldMatrix<T> b) {
  366.             final int m = pivot.length;
  367.             if (b.getRowDimension() != m) {
  368.                 throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
  369.                                                        b.getRowDimension(), m);
  370.             }
  371.             if (singular) {
  372.                 throw new MathIllegalArgumentException(LocalizedCoreFormats.SINGULAR_MATRIX);
  373.             }

  374.             final int nColB = b.getColumnDimension();

  375.             // Apply permutations to b
  376.             final T[][] bp = MathArrays.buildArray(field, m, nColB);
  377.             for (int row = 0; row < m; row++) {
  378.                 final T[] bpRow = bp[row];
  379.                 final int pRow = pivot[row];
  380.                 for (int col = 0; col < nColB; col++) {
  381.                     bpRow[col] = b.getEntry(pRow, col);
  382.                 }
  383.             }

  384.             // Solve LY = b
  385.             for (int col = 0; col < m; col++) {
  386.                 final T[] bpCol = bp[col];
  387.                 for (int i = col + 1; i < m; i++) {
  388.                     final T[] bpI = bp[i];
  389.                     final T luICol = lu[i][col];
  390.                     for (int j = 0; j < nColB; j++) {
  391.                         bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
  392.                     }
  393.                 }
  394.             }

  395.             // Solve UX = Y
  396.             for (int col = m - 1; col >= 0; col--) {
  397.                 final T[] bpCol = bp[col];
  398.                 final T luDiag = lu[col][col];
  399.                 for (int j = 0; j < nColB; j++) {
  400.                     bpCol[j] = bpCol[j].divide(luDiag);
  401.                 }
  402.                 for (int i = 0; i < col; i++) {
  403.                     final T[] bpI = bp[i];
  404.                     final T luICol = lu[i][col];
  405.                     for (int j = 0; j < nColB; j++) {
  406.                         bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
  407.                     }
  408.                 }
  409.             }

  410.             return new Array2DRowFieldMatrix<>(field, bp, false);

  411.         }

  412.         /** {@inheritDoc} */
  413.         @Override
  414.         public FieldMatrix<T> getInverse() {
  415.             return solve(MatrixUtils.createFieldIdentityMatrix(field, pivot.length));
  416.         }

  417.         /** {@inheritDoc} */
  418.         @Override
  419.         public int getRowDimension() {
  420.             return lu.length;
  421.         }

  422.         /** {@inheritDoc} */
  423.         @Override
  424.         public int getColumnDimension() {
  425.             return lu[0].length;
  426.         }

  427.     }
  428. }