FieldLUDecomposition.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* This is not the original file distributed by the Apache Software Foundation
* It has been modified by the Hipparchus project
*/
package org.hipparchus.linear;
import java.util.function.Predicate;
import org.hipparchus.Field;
import org.hipparchus.FieldElement;
import org.hipparchus.exception.LocalizedCoreFormats;
import org.hipparchus.exception.MathIllegalArgumentException;
import org.hipparchus.util.FastMath;
import org.hipparchus.util.MathArrays;
/**
 * Calculates the LUP-decomposition of a square matrix.
 * <p>The LUP-decomposition of a matrix A consists of three matrices
 * L, U and P that satisfy: PA = LU, where L is lower triangular with a unit
 * diagonal, U is upper triangular and P is a permutation matrix. All matrices
 * are m×m.</p>
 * <p>This class is based on the class with similar name from the
 * <a href="http://math.nist.gov/javanumerics/jama/">JAMA</a> library,
 * with the following changes:</p>
 * <ul>
 *   <li>a {@link #getP() getP} method has been added,</li>
 *   <li>the {@code det} method has been renamed as {@link #getDeterminant()
 *   getDeterminant},</li>
 *   <li>the {@code getDoublePivot} method has been removed (but the int based
 *   {@link #getPivot() getPivot} method has been kept),</li>
 *   <li>the {@code solve} and {@code isNonSingular} methods have been replaced
 *   by a {@link #getSolver() getSolver} method and the equivalent methods
 *   provided by the returned {@link FieldDecompositionSolver}.</li>
 * </ul>
 *
 * @param <T> the type of the field elements
 * @see <a href="http://mathworld.wolfram.com/LUDecomposition.html">MathWorld</a>
 * @see <a href="http://en.wikipedia.org/wiki/LU_decomposition">Wikipedia</a>
 */
public class FieldLUDecomposition<T extends FieldElement<T>> {

    /** Field to which the elements belong. */
    private final Field<T> field;

    /** Entries of LU decomposition (L strictly below the diagonal, U on and above it). */
    private final T[][] lu;

    /** Pivot permutation associated with LU decomposition. */
    private final int[] pivot;

    /** Parity of the permutation associated with the LU decomposition. */
    private boolean even;

    /** Singularity indicator. */
    private boolean singular;

    /** Cached value of L (lazily built by {@link #getL()}). */
    private FieldMatrix<T> cachedL;

    /** Cached value of U (lazily built by {@link #getU()}). */
    private FieldMatrix<T> cachedU;

    /** Cached value of P (lazily built by {@link #getP()}). */
    private FieldMatrix<T> cachedP;

    /**
     * Calculates the LU-decomposition of the given matrix.
     * <p>
     * By default, zero elements are detected with {@link FieldElement#isZero()}
     * and <code>numericPermutationChoice</code> is set to <code>true</code>.
     * </p>
     * @param matrix The matrix to decompose.
     * @throws MathIllegalArgumentException if matrix is not square
     * @see #FieldLUDecomposition(FieldMatrix, Predicate)
     * @see #FieldLUDecomposition(FieldMatrix, Predicate, boolean)
     */
    public FieldLUDecomposition(FieldMatrix<T> matrix) {
        this(matrix, e -> e.isZero());
    }

    /**
     * Calculates the LU-decomposition of the given matrix.
     * <p>
     * By default, <code>numericPermutationChoice</code> is set to <code>true</code>.
     * </p>
     * @param matrix The matrix to decompose.
     * @param zeroChecker checker for zero elements
     * @throws MathIllegalArgumentException if matrix is not square
     * @see #FieldLUDecomposition(FieldMatrix, Predicate, boolean)
     */
    public FieldLUDecomposition(FieldMatrix<T> matrix, final Predicate<T> zeroChecker ) {
        this(matrix, zeroChecker, true);
    }

    /**
     * Calculates the LU-decomposition of the given matrix.
     * <p>
     * If a zero pivot is met (according to {@code zeroChecker}), the decomposition
     * stops and the matrix is flagged as singular.
     * </p>
     * @param matrix The matrix to decompose.
     * @param zeroChecker checker for zero elements
     * @param numericPermutationChoice if <code>true</code> choose permutation index with numeric calculations
     * (largest absolute real part), otherwise choose the first entry that is non-zero according to
     * <code>zeroChecker</code>
     * @throws MathIllegalArgumentException if matrix is not square
     */
    public FieldLUDecomposition(FieldMatrix<T> matrix, final Predicate<T> zeroChecker, boolean numericPermutationChoice) {
        if (!matrix.isSquare()) {
            throw new MathIllegalArgumentException(LocalizedCoreFormats.NON_SQUARE_MATRIX,
                                                   matrix.getRowDimension(), matrix.getColumnDimension());
        }

        final int m = matrix.getColumnDimension();
        field = matrix.getField();
        lu    = matrix.getData();
        pivot = new int[m];

        // Start from the identity permutation, which is even
        for (int row = 0; row < m; row++) {
            pivot[row] = row;
        }
        even     = true;
        singular = false;

        // Loop over columns
        for (int col = 0; col < m; col++) {

            // upper: entries of U strictly above the diagonal in this column
            for (int row = 0; row < col; row++) {
                reduceEntry(row, col, row);
            }

            // lower: entries on and below the diagonal, plus choice of the pivot row
            final int max = numericPermutationChoice ?
                            reduceLowerLargest(col) :
                            reduceLowerFirstNonZero(col, zeroChecker);

            // Singularity check
            if (zeroChecker.test(lu[max][col])) {
                singular = true;
                return;
            }

            // Pivot if necessary; swapping the row references leaves lu with
            // exactly the same contents as swapping the rows entry by entry
            if (max != col) {
                final T[] luMax = lu[max];
                lu[max] = lu[col];
                lu[col] = luMax;
                final int tmp = pivot[max];
                pivot[max] = pivot[col];
                pivot[col] = tmp;
                even = !even;
            }

            // Divide the lower elements by the "winning" diagonal element
            final T luDiag = lu[col][col];
            for (int row = col + 1; row < m; row++) {
                lu[row][col] = lu[row][col].divide(luDiag);
            }
        }
    }

    /**
     * Replaces {@code lu[row][col]} by
     * {@code lu[row][col] - lu[row][0]*lu[0][col] - ... - lu[row][n-1]*lu[n-1][col]},
     * subtracting the terms in increasing index order.
     * @param row row index of the entry to update
     * @param col column index of the entry to update
     * @param n number of terms to subtract
     * @return the updated entry
     */
    private T reduceEntry(final int row, final int col, final int n) {
        final T[] luRow = lu[row];
        T sum = luRow[col];
        for (int i = 0; i < n; i++) {
            sum = sum.subtract(luRow[i].multiply(lu[i][col]));
        }
        luRow[col] = sum;
        return sum;
    }

    /**
     * Updates the entries of column {@code col} on and below the diagonal and selects
     * as pivot the row whose entry has the largest absolute real part.
     * @param col current column
     * @return index of the pivot row (the first one in case of ties, {@code col} if no
     * absolute value compares greater than negative infinity, e.g. all NaN)
     */
    private int reduceLowerLargest(final int col) {
        final int m = pivot.length;
        int max = col;
        double largest = Double.NEGATIVE_INFINITY;
        for (int row = col; row < m; row++) {
            final double absSum = FastMath.abs(reduceEntry(row, col, col).getReal());
            if (absSum > largest) {
                largest = absSum;
                max = row;
            }
        }
        return max;
    }

    /**
     * Updates the entries of column {@code col} on and below the diagonal and selects
     * as pivot the first row whose entry is not zero according to {@code zeroChecker}.
     * @param col current column
     * @param zeroChecker checker for zero elements
     * @return index of the pivot row, or the last row if all candidate entries are zero
     */
    private int reduceLowerFirstNonZero(final int col, final Predicate<T> zeroChecker) {
        final int m = pivot.length;
        int nonZero = col;
        for (int row = col; row < m; row++) {
            reduceEntry(row, col, col);
            // nonZero never exceeds row here, so lu[nonZero][col] is already updated;
            // each iteration checks one more candidate, scanning rows in order
            if (zeroChecker.test(lu[nonZero][col])) {
                ++nonZero;
            }
        }
        return FastMath.min(m - 1, nonZero);
    }

    /**
     * Returns the matrix L of the decomposition.
     * <p>L is a lower-triangular matrix with ones on its diagonal.</p>
     * @return the L matrix (or null if decomposed matrix is singular)
     */
    public FieldMatrix<T> getL() {
        if ((cachedL == null) && !singular) {
            final int m = pivot.length;
            cachedL = new Array2DRowFieldMatrix<>(field, m, m);
            for (int i = 0; i < m; ++i) {
                final T[] luI = lu[i];
                for (int j = 0; j < i; ++j) {
                    cachedL.setEntry(i, j, luI[j]);
                }
                cachedL.setEntry(i, i, field.getOne());
            }
        }
        return cachedL;
    }

    /**
     * Returns the matrix U of the decomposition.
     * <p>U is an upper-triangular matrix</p>
     * @return the U matrix (or null if decomposed matrix is singular)
     */
    public FieldMatrix<T> getU() {
        if ((cachedU == null) && !singular) {
            final int m = pivot.length;
            cachedU = new Array2DRowFieldMatrix<>(field, m, m);
            for (int i = 0; i < m; ++i) {
                final T[] luI = lu[i];
                for (int j = i; j < m; ++j) {
                    cachedU.setEntry(i, j, luI[j]);
                }
            }
        }
        return cachedU;
    }

    /**
     * Returns the P rows permutation matrix.
     * <p>P is a sparse matrix with exactly one element set to one in
     * each row and each column, all other elements being set to zero.</p>
     * <p>The positions of the one elements are given by the {@link #getPivot()
     * pivot permutation vector}.</p>
     * @return the P rows permutation matrix (or null if decomposed matrix is singular)
     * @see #getPivot()
     */
    public FieldMatrix<T> getP() {
        if ((cachedP == null) && !singular) {
            final int m = pivot.length;
            cachedP = new Array2DRowFieldMatrix<>(field, m, m);
            for (int i = 0; i < m; ++i) {
                cachedP.setEntry(i, pivot[i], field.getOne());
            }
        }
        return cachedP;
    }

    /**
     * Returns the pivot permutation vector.
     * @return a copy of the pivot permutation vector
     * @see #getP()
     */
    public int[] getPivot() {
        return pivot.clone();
    }

    /**
     * Return the determinant of the matrix.
     * @return determinant of the matrix (zero if the matrix was detected as singular)
     */
    public T getDeterminant() {
        if (singular) {
            return field.getZero();
        } else {
            final int m = pivot.length;
            // sign of the permutation times the product of U diagonal entries
            T determinant = even ? field.getOne() : field.getZero().subtract(field.getOne());
            for (int i = 0; i < m; i++) {
                determinant = determinant.multiply(lu[i][i]);
            }
            return determinant;
        }
    }

    /**
     * Get a solver for finding the A × X = B solution in exact linear sense.
     * @return a solver
     */
    public FieldDecompositionSolver<T> getSolver() {
        return new Solver();
    }

    /** Specialized solver using the LU decomposition held by the enclosing instance.
     */
    private class Solver implements FieldDecompositionSolver<T> {

        /** {@inheritDoc} */
        @Override
        public boolean isNonSingular() {
            return !singular;
        }

        /** {@inheritDoc} */
        @Override
        public FieldVector<T> solve(FieldVector<T> b) {
            return solveVector(b);
        }

        /** Solve the linear equation A × X = B.
         * <p>The A matrix is implicit here. It is the matrix whose LU decomposition
         * was computed by the enclosing {@link FieldLUDecomposition}.</p>
         * @param b right-hand side of the equation A × X = B
         * @return a vector X such that A × X = B
         * @throws MathIllegalArgumentException if the matrices dimensions do not match.
         * @throws MathIllegalArgumentException if the decomposed matrix is singular.
         */
        public ArrayFieldVector<T> solve(ArrayFieldVector<T> b) {
            return solveVector(b);
        }

        /** Solve A × X = b for a single right-hand side vector.
         * @param b right-hand side of the equation
         * @return a vector X such that A × X = b
         * @throws MathIllegalArgumentException if the dimensions do not match
         * or the decomposed matrix is singular
         */
        private ArrayFieldVector<T> solveVector(final FieldVector<T> b) {
            final int m = pivot.length;
            checkSolvable(b.getDimension());

            // Apply permutations to b
            final T[] bp = MathArrays.buildArray(field, m);
            for (int row = 0; row < m; row++) {
                bp[row] = b.getEntry(pivot[row]);
            }

            // Solve LY = b (L has a unit diagonal, so no division is needed)
            for (int col = 0; col < m; col++) {
                final T bpCol = bp[col];
                for (int i = col + 1; i < m; i++) {
                    bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
                }
            }

            // Solve UX = Y
            for (int col = m - 1; col >= 0; col--) {
                bp[col] = bp[col].divide(lu[col][col]);
                final T bpCol = bp[col];
                for (int i = 0; i < col; i++) {
                    bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
                }
            }

            // pass the known field explicitly rather than letting the vector infer it from its elements
            return new ArrayFieldVector<>(field, bp, false);
        }

        /** Check that a right-hand side with the given row count can be solved.
         * @param dimension number of rows of the right-hand side
         * @throws MathIllegalArgumentException if the dimension does not match
         * the decomposed matrix, or if the decomposed matrix is singular
         */
        private void checkSolvable(final int dimension) {
            final int m = pivot.length;
            if (dimension != m) {
                throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
                                                       dimension, m);
            }
            if (singular) {
                throw new MathIllegalArgumentException(LocalizedCoreFormats.SINGULAR_MATRIX);
            }
        }

        /** {@inheritDoc} */
        @Override
        public FieldMatrix<T> solve(FieldMatrix<T> b) {
            final int m = pivot.length;
            checkSolvable(b.getRowDimension());

            final int nColB = b.getColumnDimension();

            // Apply permutations to b
            final T[][] bp = MathArrays.buildArray(field, m, nColB);
            for (int row = 0; row < m; row++) {
                final T[] bpRow = bp[row];
                final int pRow = pivot[row];
                for (int col = 0; col < nColB; col++) {
                    bpRow[col] = b.getEntry(pRow, col);
                }
            }

            // Solve LY = b (L has a unit diagonal, so no division is needed)
            for (int col = 0; col < m; col++) {
                final T[] bpCol = bp[col];
                for (int i = col + 1; i < m; i++) {
                    final T[] bpI = bp[i];
                    final T luICol = lu[i][col];
                    for (int j = 0; j < nColB; j++) {
                        bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
                    }
                }
            }

            // Solve UX = Y
            for (int col = m - 1; col >= 0; col--) {
                final T[] bpCol = bp[col];
                final T luDiag = lu[col][col];
                for (int j = 0; j < nColB; j++) {
                    bpCol[j] = bpCol[j].divide(luDiag);
                }
                for (int i = 0; i < col; i++) {
                    final T[] bpI = bp[i];
                    final T luICol = lu[i][col];
                    for (int j = 0; j < nColB; j++) {
                        bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
                    }
                }
            }

            return new Array2DRowFieldMatrix<T>(field, bp, false);
        }

        /** {@inheritDoc} */
        @Override
        public FieldMatrix<T> getInverse() {
            return solve(MatrixUtils.createFieldIdentityMatrix(field, pivot.length));
        }

        /** {@inheritDoc} */
        @Override
        public int getRowDimension() {
            return pivot.length;
        }

        /** {@inheritDoc} */
        @Override
        public int getColumnDimension() {
            // the decomposed matrix is square; unlike lu[0].length this also works when m == 0
            return pivot.length;
        }

    }

}