1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23 package org.hipparchus.linear;
24
25 import java.util.Arrays;
26
27 import org.hipparchus.exception.LocalizedCoreFormats;
28 import org.hipparchus.exception.MathIllegalArgumentException;
29 import org.hipparchus.util.FastMath;
30 import org.junit.Assert;
31 import org.junit.Test;
32
33 public class TriDiagonalTransformerTest {
34
35 private double[][] testSquare5 = {
36 { 1, 2, 3, 1, 1 },
37 { 2, 1, 1, 3, 1 },
38 { 3, 1, 1, 1, 2 },
39 { 1, 3, 1, 2, 1 },
40 { 1, 1, 2, 1, 3 }
41 };
42
43 private double[][] testSquare3 = {
44 { 1, 3, 4 },
45 { 3, 2, 2 },
46 { 4, 2, 0 }
47 };
48
49 @Test
50 public void testNonSquare() {
51 try {
52 new TriDiagonalTransformer(MatrixUtils.createRealMatrix(new double[3][2]));
53 Assert.fail("an exception should have been thrown");
54 } catch (MathIllegalArgumentException ime) {
55 Assert.assertEquals(LocalizedCoreFormats.NON_SQUARE_MATRIX, ime.getSpecifier());
56 }
57 }
58
59 @Test
60 public void testAEqualQTQt() {
61 checkAEqualQTQt(MatrixUtils.createRealMatrix(testSquare5));
62 checkAEqualQTQt(MatrixUtils.createRealMatrix(testSquare3));
63 }
64
65 private void checkAEqualQTQt(RealMatrix matrix) {
66 TriDiagonalTransformer transformer = new TriDiagonalTransformer(matrix);
67 RealMatrix q = transformer.getQ();
68 RealMatrix qT = transformer.getQT();
69 RealMatrix t = transformer.getT();
70 double norm = q.multiply(t).multiply(qT).subtract(matrix).getNorm1();
71 Assert.assertEquals(0, norm, 4.0e-15);
72 }
73
74 @Test
75 public void testNoAccessBelowDiagonal() {
76 checkNoAccessBelowDiagonal(testSquare5);
77 checkNoAccessBelowDiagonal(testSquare3);
78 }
79
80 private void checkNoAccessBelowDiagonal(double[][] data) {
81 double[][] modifiedData = new double[data.length][];
82 for (int i = 0; i < data.length; ++i) {
83 modifiedData[i] = data[i].clone();
84 Arrays.fill(modifiedData[i], 0, i, Double.NaN);
85 }
86 RealMatrix matrix = MatrixUtils.createRealMatrix(modifiedData);
87 TriDiagonalTransformer transformer = new TriDiagonalTransformer(matrix);
88 RealMatrix q = transformer.getQ();
89 RealMatrix qT = transformer.getQT();
90 RealMatrix t = transformer.getT();
91 double norm = q.multiply(t).multiply(qT).subtract(MatrixUtils.createRealMatrix(data)).getNorm1();
92 Assert.assertEquals(0, norm, 4.0e-15);
93 }
94
95 @Test
96 public void testQOrthogonal() {
97 checkOrthogonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare5)).getQ());
98 checkOrthogonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare3)).getQ());
99 }
100
101 @Test
102 public void testQTOrthogonal() {
103 checkOrthogonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare5)).getQT());
104 checkOrthogonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare3)).getQT());
105 }
106
107 private void checkOrthogonal(RealMatrix m) {
108 RealMatrix mTm = m.transposeMultiply(m);
109 RealMatrix id = MatrixUtils.createRealIdentityMatrix(mTm.getRowDimension());
110 Assert.assertEquals(0, mTm.subtract(id).getNorm1(), 1.0e-15);
111 }
112
113 @Test
114 public void testTTriDiagonal() {
115 checkTriDiagonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare5)).getT());
116 checkTriDiagonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare3)).getT());
117 }
118
119 private void checkTriDiagonal(RealMatrix m) {
120 final int rows = m.getRowDimension();
121 final int cols = m.getColumnDimension();
122 for (int i = 0; i < rows; ++i) {
123 for (int j = 0; j < cols; ++j) {
124 if ((i < j - 1) || (i > j + 1)) {
125 Assert.assertEquals(0, m.getEntry(i, j), 1.0e-16);
126 }
127 }
128 }
129 }
130
131 @Test
132 public void testMatricesValues5() {
133 checkMatricesValues(testSquare5,
134 new double[][] {
135 { 1.0, 0.0, 0.0, 0.0, 0.0 },
136 { 0.0, -0.5163977794943222, 0.016748280772542083, 0.839800693771262, 0.16669620021405473 },
137 { 0.0, -0.7745966692414833, -0.4354553000860955, -0.44989322880603355, -0.08930153582895772 },
138 { 0.0, -0.2581988897471611, 0.6364346693566014, -0.30263204032131164, 0.6608313651342882 },
139 { 0.0, -0.2581988897471611, 0.6364346693566009, -0.027289660803112598, -0.7263191580755246 }
140 },
141 new double[] { 1, 4.4, 1.433099579242636, -0.89537362758743, 2.062274048344794 },
142 new double[] { -FastMath.sqrt(15), -3.0832882879592476, 0.6082710842351517, 1.1786086405912128 });
143 }
144
145 @Test
146 public void testMatricesValues3() {
147 checkMatricesValues(testSquare3,
148 new double[][] {
149 { 1.0, 0.0, 0.0 },
150 { 0.0, -0.6, 0.8 },
151 { 0.0, -0.8, -0.6 },
152 },
153 new double[] { 1, 2.64, -0.64 },
154 new double[] { -5, -1.52 });
155 }
156
157 private void checkMatricesValues(double[][] matrix, double[][] qRef,
158 double[] mainDiagnonal,
159 double[] secondaryDiagonal) {
160 TriDiagonalTransformer transformer =
161 new TriDiagonalTransformer(MatrixUtils.createRealMatrix(matrix));
162
163
164 RealMatrix q = transformer.getQ();
165 Assert.assertEquals(0, q.subtract(MatrixUtils.createRealMatrix(qRef)).getNorm1(), 1.0e-14);
166
167 RealMatrix t = transformer.getT();
168 double[][] tData = new double[mainDiagnonal.length][mainDiagnonal.length];
169 for (int i = 0; i < mainDiagnonal.length; ++i) {
170 tData[i][i] = mainDiagnonal[i];
171 if (i > 0) {
172 tData[i][i - 1] = secondaryDiagonal[i - 1];
173 }
174 if (i < secondaryDiagonal.length) {
175 tData[i][i + 1] = secondaryDiagonal[i];
176 }
177 }
178 Assert.assertEquals(0, t.subtract(MatrixUtils.createRealMatrix(tData)).getNorm1(), 1.0e-14);
179
180
181 Assert.assertTrue(q == transformer.getQ());
182 Assert.assertTrue(t == transformer.getT());
183 }
184 }