1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * https://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 /*
19 * This is not the original file distributed by the Apache Software Foundation
20 * It has been modified by the Hipparchus project
21 */
22
23 package org.hipparchus.linear;
24
25 import java.util.function.Predicate;
26
27 import org.hipparchus.Field;
28 import org.hipparchus.FieldElement;
29 import org.hipparchus.exception.LocalizedCoreFormats;
30 import org.hipparchus.exception.MathIllegalArgumentException;
31 import org.hipparchus.util.FastMath;
32 import org.hipparchus.util.MathArrays;
33
34 /**
35 * Calculates the LUP-decomposition of a square matrix.
 * <p>The LUP-decomposition of a matrix A consists of three matrices
 * L, U and P that satisfy PA = LU, where L is lower triangular with a
 * unit diagonal, U is upper triangular and P is a permutation matrix.
 * All matrices are m×m.</p>
 * <p>This class is based on the class with similar name from the
 * <a href="http://math.nist.gov/javanumerics/jama/">JAMA</a> library,
 * with the following changes:</p>
42 * <ul>
43 * <li>a {@link #getP() getP} method has been added,</li>
44 * <li>the {@code det} method has been renamed as {@link #getDeterminant()
45 * getDeterminant},</li>
46 * <li>the {@code getDoublePivot} method has been removed (but the int based
47 * {@link #getPivot() getPivot} method has been kept),</li>
48 * <li>the {@code solve} and {@code isNonSingular} methods have been replaced
49 * by a {@link #getSolver() getSolver} method and the equivalent methods
50 * provided by the returned {@link DecompositionSolver}.</li>
51 * </ul>
52 *
53 * @param <T> the type of the field elements
54 * @see <a href="http://mathworld.wolfram.com/LUDecomposition.html">MathWorld</a>
55 * @see <a href="http://en.wikipedia.org/wiki/LU_decomposition">Wikipedia</a>
56 */
57 public class FieldLUDecomposition<T extends FieldElement<T>> {
58
59 /** Field to which the elements belong. */
60 private final Field<T> field;
61
62 /** Entries of LU decomposition. */
63 private T[][] lu;
64
65 /** Pivot permutation associated with LU decomposition. */
66 private int[] pivot;
67
68 /** Parity of the permutation associated with the LU decomposition. */
69 private boolean even;
70
71 /** Singularity indicator. */
72 private boolean singular;
73
74 /** Cached value of L. */
75 private FieldMatrix<T> cachedL;
76
77 /** Cached value of U. */
78 private FieldMatrix<T> cachedU;
79
80 /** Cached value of P. */
81 private FieldMatrix<T> cachedP;
82
83 /**
84 * Calculates the LU-decomposition of the given matrix.
85 * <p>
86 * By default, <code>numericPermutationChoice</code> is set to <code>true</code>.
87 * </p>
88 * @param matrix The matrix to decompose.
89 * @throws MathIllegalArgumentException if matrix is not square
90 * @see #FieldLUDecomposition(FieldMatrix, Predicate)
91 * @see #FieldLUDecomposition(FieldMatrix, Predicate, boolean)
92 */
93 public FieldLUDecomposition(FieldMatrix<T> matrix) {
94 this(matrix, FieldElement::isZero);
95 }
96
97 /**
98 * Calculates the LU-decomposition of the given matrix.
99 * <p>
100 * By default, <code>numericPermutationChoice</code> is set to <code>true</code>.
101 * </p>
102 * @param matrix The matrix to decompose.
103 * @param zeroChecker checker for zero elements
104 * @throws MathIllegalArgumentException if matrix is not square
105 * @see #FieldLUDecomposition(FieldMatrix, Predicate, boolean)
106 */
107 public FieldLUDecomposition(FieldMatrix<T> matrix, final Predicate<T> zeroChecker ) {
108 this(matrix, zeroChecker, true);
109 }
110
111 /**
112 * Calculates the LU-decomposition of the given matrix.
113 * @param matrix The matrix to decompose.
114 * @param zeroChecker checker for zero elements
115 * @param numericPermutationChoice if <code>true</code> choose permutation index with numeric calculations, otherwise choose with <code>zeroChecker</code>
116 * @throws MathIllegalArgumentException if matrix is not square
117 */
118 public FieldLUDecomposition(FieldMatrix<T> matrix, final Predicate<T> zeroChecker, boolean numericPermutationChoice) {
119 if (!matrix.isSquare()) {
120 throw new MathIllegalArgumentException(LocalizedCoreFormats.NON_SQUARE_MATRIX,
121 matrix.getRowDimension(), matrix.getColumnDimension());
122 }
123
124 final int m = matrix.getColumnDimension();
125 field = matrix.getField();
126 lu = matrix.getData();
127 pivot = new int[m];
128 cachedL = null;
129 cachedU = null;
130 cachedP = null;
131
132 // Initialize permutation array and parity
133 for (int row = 0; row < m; row++) {
134 pivot[row] = row;
135 }
136 even = true;
137 singular = false;
138
139 // Loop over columns
140 for (int col = 0; col < m; col++) {
141
142 // upper
143 for (int row = 0; row < col; row++) {
144 final T[] luRow = lu[row];
145 T sum = luRow[col];
146 for (int i = 0; i < row; i++) {
147 sum = sum.subtract(luRow[i].multiply(lu[i][col]));
148 }
149 luRow[col] = sum;
150 }
151
152 int max = col; // permutation row
153 if (numericPermutationChoice) {
154
155 // lower
156 double largest = Double.NEGATIVE_INFINITY;
157
158 for (int row = col; row < m; row++) {
159 final T[] luRow = lu[row];
160 T sum = luRow[col];
161 for (int i = 0; i < col; i++) {
162 sum = sum.subtract(luRow[i].multiply(lu[i][col]));
163 }
164 luRow[col] = sum;
165
166 // maintain best permutation choice
167 double absSum = FastMath.abs(sum.getReal());
168 if (absSum > largest) {
169 largest = absSum;
170 max = row;
171 }
172 }
173
174 } else {
175
176 // lower
177 int nonZero = col; // permutation row
178 for (int row = col; row < m; row++) {
179 final T[] luRow = lu[row];
180 T sum = luRow[col];
181 for (int i = 0; i < col; i++) {
182 sum = sum.subtract(luRow[i].multiply(lu[i][col]));
183 }
184 luRow[col] = sum;
185
186 if (zeroChecker.test(lu[nonZero][col])) {
187 // try to select a better permutation choice
188 ++nonZero;
189 }
190 }
191 max = FastMath.min(m - 1, nonZero);
192
193 }
194
195 // Singularity check
196 if (zeroChecker.test(lu[max][col])) {
197 singular = true;
198 return;
199 }
200
201 // Pivot if necessary
202 if (max != col) {
203 final T[] luMax = lu[max];
204 final T[] luCol = lu[col];
205 for (int i = 0; i < m; i++) {
206 final T tmp = luMax[i];
207 luMax[i] = luCol[i];
208 luCol[i] = tmp;
209 }
210 int temp = pivot[max];
211 pivot[max] = pivot[col];
212 pivot[col] = temp;
213 even = !even;
214 }
215
216 // Divide the lower elements by the "winning" diagonal elt.
217 final T luDiag = lu[col][col];
218 for (int row = col + 1; row < m; row++) {
219 lu[row][col] = lu[row][col].divide(luDiag);
220 }
221 }
222
223 }
224
225 /**
226 * Returns the matrix L of the decomposition.
227 * <p>L is a lower-triangular matrix</p>
228 * @return the L matrix (or null if decomposed matrix is singular)
229 */
230 public FieldMatrix<T> getL() {
231 if ((cachedL == null) && !singular) {
232 final int m = pivot.length;
233 cachedL = new Array2DRowFieldMatrix<>(field, m, m);
234 for (int i = 0; i < m; ++i) {
235 final T[] luI = lu[i];
236 for (int j = 0; j < i; ++j) {
237 cachedL.setEntry(i, j, luI[j]);
238 }
239 cachedL.setEntry(i, i, field.getOne());
240 }
241 }
242 return cachedL;
243 }
244
245 /**
246 * Returns the matrix U of the decomposition.
247 * <p>U is an upper-triangular matrix</p>
248 * @return the U matrix (or null if decomposed matrix is singular)
249 */
250 public FieldMatrix<T> getU() {
251 if ((cachedU == null) && !singular) {
252 final int m = pivot.length;
253 cachedU = new Array2DRowFieldMatrix<>(field, m, m);
254 for (int i = 0; i < m; ++i) {
255 final T[] luI = lu[i];
256 for (int j = i; j < m; ++j) {
257 cachedU.setEntry(i, j, luI[j]);
258 }
259 }
260 }
261 return cachedU;
262 }
263
264 /**
265 * Returns the P rows permutation matrix.
266 * <p>P is a sparse matrix with exactly one element set to 1.0 in
267 * each row and each column, all other elements being set to 0.0.</p>
268 * <p>The positions of the 1 elements are given by the {@link #getPivot()
269 * pivot permutation vector}.</p>
270 * @return the P rows permutation matrix (or null if decomposed matrix is singular)
271 * @see #getPivot()
272 */
273 public FieldMatrix<T> getP() {
274 if ((cachedP == null) && !singular) {
275 final int m = pivot.length;
276 cachedP = new Array2DRowFieldMatrix<>(field, m, m);
277 for (int i = 0; i < m; ++i) {
278 cachedP.setEntry(i, pivot[i], field.getOne());
279 }
280 }
281 return cachedP;
282 }
283
284 /**
285 * Returns the pivot permutation vector.
286 * @return the pivot permutation vector
287 * @see #getP()
288 */
289 public int[] getPivot() {
290 return pivot.clone();
291 }
292
293 /**
294 * Return the determinant of the matrix.
295 * @return determinant of the matrix
296 */
297 public T getDeterminant() {
298 if (singular) {
299 return field.getZero();
300 } else {
301 final int m = pivot.length;
302 T determinant = even ? field.getOne() : field.getZero().subtract(field.getOne());
303 for (int i = 0; i < m; i++) {
304 determinant = determinant.multiply(lu[i][i]);
305 }
306 return determinant;
307 }
308 }
309
310 /**
311 * Get a solver for finding the A × X = B solution in exact linear sense.
312 * @return a solver
313 */
314 public FieldDecompositionSolver<T> getSolver() {
315 return new Solver();
316 }
317
318 /** Specialized solver.
319 */
320 private class Solver implements FieldDecompositionSolver<T> {
321
322 /** {@inheritDoc} */
323 @Override
324 public boolean isNonSingular() {
325 return !singular;
326 }
327
328 /** {@inheritDoc} */
329 @Override
330 public FieldVector<T> solve(FieldVector<T> b) {
331 if (b instanceof ArrayFieldVector) {
332 return solve((ArrayFieldVector<T>) b);
333 } else {
334
335 final int m = pivot.length;
336 if (b.getDimension() != m) {
337 throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
338 b.getDimension(), m);
339 }
340 if (singular) {
341 throw new MathIllegalArgumentException(LocalizedCoreFormats.SINGULAR_MATRIX);
342 }
343
344 // Apply permutations to b
345 final T[] bp = MathArrays.buildArray(field, m);
346 for (int row = 0; row < m; row++) {
347 bp[row] = b.getEntry(pivot[row]);
348 }
349
350 // Solve LY = b
351 for (int col = 0; col < m; col++) {
352 final T bpCol = bp[col];
353 for (int i = col + 1; i < m; i++) {
354 bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
355 }
356 }
357
358 // Solve UX = Y
359 for (int col = m - 1; col >= 0; col--) {
360 bp[col] = bp[col].divide(lu[col][col]);
361 final T bpCol = bp[col];
362 for (int i = 0; i < col; i++) {
363 bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
364 }
365 }
366
367 return new ArrayFieldVector<>(field, bp, false);
368
369 }
370 }
371
372 /** Solve the linear equation A × X = B.
373 * <p>The A matrix is implicit here. It is </p>
374 * @param b right-hand side of the equation A × X = B
375 * @return a vector X such that A × X = B
376 * @throws MathIllegalArgumentException if the matrices dimensions do not match.
377 * @throws MathIllegalArgumentException if the decomposed matrix is singular.
378 */
379 public ArrayFieldVector<T> solve(ArrayFieldVector<T> b) {
380 final int m = pivot.length;
381 final int length = b.getDimension();
382 if (length != m) {
383 throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
384 length, m);
385 }
386 if (singular) {
387 throw new MathIllegalArgumentException(LocalizedCoreFormats.SINGULAR_MATRIX);
388 }
389
390 // Apply permutations to b
391 final T[] bp = MathArrays.buildArray(field, m);
392 for (int row = 0; row < m; row++) {
393 bp[row] = b.getEntry(pivot[row]);
394 }
395
396 // Solve LY = b
397 for (int col = 0; col < m; col++) {
398 final T bpCol = bp[col];
399 for (int i = col + 1; i < m; i++) {
400 bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
401 }
402 }
403
404 // Solve UX = Y
405 for (int col = m - 1; col >= 0; col--) {
406 bp[col] = bp[col].divide(lu[col][col]);
407 final T bpCol = bp[col];
408 for (int i = 0; i < col; i++) {
409 bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
410 }
411 }
412
413 return new ArrayFieldVector<>(bp, false);
414 }
415
416 /** {@inheritDoc} */
417 @Override
418 public FieldMatrix<T> solve(FieldMatrix<T> b) {
419 final int m = pivot.length;
420 if (b.getRowDimension() != m) {
421 throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
422 b.getRowDimension(), m);
423 }
424 if (singular) {
425 throw new MathIllegalArgumentException(LocalizedCoreFormats.SINGULAR_MATRIX);
426 }
427
428 final int nColB = b.getColumnDimension();
429
430 // Apply permutations to b
431 final T[][] bp = MathArrays.buildArray(field, m, nColB);
432 for (int row = 0; row < m; row++) {
433 final T[] bpRow = bp[row];
434 final int pRow = pivot[row];
435 for (int col = 0; col < nColB; col++) {
436 bpRow[col] = b.getEntry(pRow, col);
437 }
438 }
439
440 // Solve LY = b
441 for (int col = 0; col < m; col++) {
442 final T[] bpCol = bp[col];
443 for (int i = col + 1; i < m; i++) {
444 final T[] bpI = bp[i];
445 final T luICol = lu[i][col];
446 for (int j = 0; j < nColB; j++) {
447 bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
448 }
449 }
450 }
451
452 // Solve UX = Y
453 for (int col = m - 1; col >= 0; col--) {
454 final T[] bpCol = bp[col];
455 final T luDiag = lu[col][col];
456 for (int j = 0; j < nColB; j++) {
457 bpCol[j] = bpCol[j].divide(luDiag);
458 }
459 for (int i = 0; i < col; i++) {
460 final T[] bpI = bp[i];
461 final T luICol = lu[i][col];
462 for (int j = 0; j < nColB; j++) {
463 bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
464 }
465 }
466 }
467
468 return new Array2DRowFieldMatrix<>(field, bp, false);
469
470 }
471
472 /** {@inheritDoc} */
473 @Override
474 public FieldMatrix<T> getInverse() {
475 return solve(MatrixUtils.createFieldIdentityMatrix(field, pivot.length));
476 }
477
478 /** {@inheritDoc} */
479 @Override
480 public int getRowDimension() {
481 return lu.length;
482 }
483
484 /** {@inheritDoc} */
485 @Override
486 public int getColumnDimension() {
487 return lu[0].length;
488 }
489
490 }
491 }