1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 package org.hipparchus.linear;
23
24 import java.util.Arrays;
25
26 import org.hipparchus.exception.MathIllegalArgumentException;
27 import org.hipparchus.exception.MathIllegalStateException;
28 import org.hipparchus.exception.MathRuntimeException;
29 import org.hipparchus.util.FastMath;
30 import org.hipparchus.util.IterationEvent;
31 import org.hipparchus.util.IterationListener;
32 import org.junit.Assert;
33 import org.junit.Test;
34
35 public class ConjugateGradientTest {
36
37 @Test(expected = MathIllegalArgumentException.class)
38 public void testNonSquareOperator() {
39 final Array2DRowRealMatrix a = new Array2DRowRealMatrix(2, 3);
40 final IterativeLinearSolver solver;
41 solver = new ConjugateGradient(10, 0., false);
42 final ArrayRealVector b = new ArrayRealVector(a.getRowDimension());
43 final ArrayRealVector x = new ArrayRealVector(a.getColumnDimension());
44 solver.solve(a, b, x);
45 }
46
47 @Test(expected = MathIllegalArgumentException.class)
48 public void testDimensionMismatchRightHandSide() {
49 final Array2DRowRealMatrix a = new Array2DRowRealMatrix(3, 3);
50 final IterativeLinearSolver solver;
51 solver = new ConjugateGradient(10, 0., false);
52 final ArrayRealVector b = new ArrayRealVector(2);
53 final ArrayRealVector x = new ArrayRealVector(3);
54 solver.solve(a, b, x);
55 }
56
57 @Test(expected = MathIllegalArgumentException.class)
58 public void testDimensionMismatchSolution() {
59 final Array2DRowRealMatrix a = new Array2DRowRealMatrix(3, 3);
60 final IterativeLinearSolver solver;
61 solver = new ConjugateGradient(10, 0., false);
62 final ArrayRealVector b = new ArrayRealVector(3);
63 final ArrayRealVector x = new ArrayRealVector(2);
64 solver.solve(a, b, x);
65 }
66
67 @Test(expected = MathIllegalArgumentException.class)
68 public void testNonPositiveDefiniteLinearOperator() {
69 final Array2DRowRealMatrix a = new Array2DRowRealMatrix(2, 2);
70 a.setEntry(0, 0, -1.);
71 a.setEntry(0, 1, 2.);
72 a.setEntry(1, 0, 3.);
73 a.setEntry(1, 1, 4.);
74 final IterativeLinearSolver solver;
75 solver = new ConjugateGradient(10, 0., true);
76 final ArrayRealVector b = new ArrayRealVector(2);
77 b.setEntry(0, -1.);
78 b.setEntry(1, -1.);
79 final ArrayRealVector x = new ArrayRealVector(2);
80 solver.solve(a, b, x);
81 }
82
83 @Test
84 public void testUnpreconditionedSolution() {
85 final int n = 5;
86 final int maxIterations = 100;
87 final RealLinearOperator a = new HilbertMatrix(n);
88 final InverseHilbertMatrix ainv = new InverseHilbertMatrix(n);
89 final IterativeLinearSolver solver;
90 solver = new ConjugateGradient(maxIterations, 1E-10, true);
91 final RealVector b = new ArrayRealVector(n);
92 for (int j = 0; j < n; j++) {
93 b.set(0.);
94 b.setEntry(j, 1.);
95 final RealVector x = solver.solve(a, b);
96 for (int i = 0; i < n; i++) {
97 final double actual = x.getEntry(i);
98 final double expected = ainv.getEntry(i, j);
99 final double delta = 1E-10 * FastMath.abs(expected);
100 final String msg = String.format("entry[%d][%d]", i, j);
101 Assert.assertEquals(msg, expected, actual, delta);
102 }
103 }
104 }
105
106 @Test
107 public void testUnpreconditionedInPlaceSolutionWithInitialGuess() {
108 final int n = 5;
109 final int maxIterations = 100;
110 final RealLinearOperator a = new HilbertMatrix(n);
111 final InverseHilbertMatrix ainv = new InverseHilbertMatrix(n);
112 final IterativeLinearSolver solver;
113 solver = new ConjugateGradient(maxIterations, 1E-10, true);
114 final RealVector b = new ArrayRealVector(n);
115 for (int j = 0; j < n; j++) {
116 b.set(0.);
117 b.setEntry(j, 1.);
118 final RealVector x0 = new ArrayRealVector(n);
119 x0.set(1.);
120 final RealVector x = solver.solveInPlace(a, b, x0);
121 Assert.assertSame("x should be a reference to x0", x0, x);
122 for (int i = 0; i < n; i++) {
123 final double actual = x.getEntry(i);
124 final double expected = ainv.getEntry(i, j);
125 final double delta = 1E-10 * FastMath.abs(expected);
126 final String msg = String.format("entry[%d][%d)", i, j);
127 Assert.assertEquals(msg, expected, actual, delta);
128 }
129 }
130 }
131
132 @Test
133 public void testUnpreconditionedSolutionWithInitialGuess() {
134 final int n = 5;
135 final int maxIterations = 100;
136 final RealLinearOperator a = new HilbertMatrix(n);
137 final InverseHilbertMatrix ainv = new InverseHilbertMatrix(n);
138 final IterativeLinearSolver solver;
139 solver = new ConjugateGradient(maxIterations, 1E-10, true);
140 final RealVector b = new ArrayRealVector(n);
141 for (int j = 0; j < n; j++) {
142 b.set(0.);
143 b.setEntry(j, 1.);
144 final RealVector x0 = new ArrayRealVector(n);
145 x0.set(1.);
146 final RealVector x = solver.solve(a, b, x0);
147 Assert.assertNotSame("x should not be a reference to x0", x0, x);
148 for (int i = 0; i < n; i++) {
149 final double actual = x.getEntry(i);
150 final double expected = ainv.getEntry(i, j);
151 final double delta = 1E-10 * FastMath.abs(expected);
152 final String msg = String.format("entry[%d][%d]", i, j);
153 Assert.assertEquals(msg, expected, actual, delta);
154 Assert.assertEquals(msg, x0.getEntry(i), 1., Math.ulp(1.));
155 }
156 }
157 }
158
159
160
161
162
163
164
165 @Test
166 public void testUnpreconditionedResidual() {
167 final int n = 10;
168 final int maxIterations = n;
169 final RealLinearOperator a = new HilbertMatrix(n);
170 final ConjugateGradient solver;
171 solver = new ConjugateGradient(maxIterations, 1E-15, true);
172 final RealVector r = new ArrayRealVector(n);
173 final RealVector x = new ArrayRealVector(n);
174 final IterationListener listener = new IterationListener() {
175
176 public void terminationPerformed(final IterationEvent e) {
177
178 }
179
180 public void iterationStarted(final IterationEvent e) {
181
182 }
183
184 public void iterationPerformed(final IterationEvent e) {
185 final IterativeLinearSolverEvent evt;
186 evt = (IterativeLinearSolverEvent) e;
187 RealVector v = evt.getResidual();
188 r.setSubVector(0, v);
189 v = evt.getSolution();
190 x.setSubVector(0, v);
191 }
192
193 public void initializationPerformed(final IterationEvent e) {
194
195 }
196 };
197 solver.getIterationManager().addIterationListener(listener);
198 final RealVector b = new ArrayRealVector(n);
199 for (int j = 0; j < n; j++) {
200 b.set(0.);
201 b.setEntry(j, 1.);
202
203 boolean caught = false;
204 try {
205 solver.solve(a, b);
206 } catch (MathIllegalStateException e) {
207 caught = true;
208 final RealVector y = a.operate(x);
209 for (int i = 0; i < n; i++) {
210 final double actual = b.getEntry(i) - y.getEntry(i);
211 final double expected = r.getEntry(i);
212 final double delta = 1E-6 * FastMath.abs(expected);
213 final String msg = String
214 .format("column %d, residual %d", i, j);
215 Assert.assertEquals(msg, expected, actual, delta);
216 }
217 }
218 Assert
219 .assertTrue("MathIllegalStateException should have been caught",
220 caught);
221 }
222 }
223
224 @Test(expected = MathIllegalArgumentException.class)
225 public void testNonSquarePreconditioner() {
226 final Array2DRowRealMatrix a = new Array2DRowRealMatrix(2, 2);
227 final RealLinearOperator m = new RealLinearOperator() {
228
229 @Override
230 public RealVector operate(final RealVector x) {
231 throw new UnsupportedOperationException();
232 }
233
234 @Override
235 public int getRowDimension() {
236 return 2;
237 }
238
239 @Override
240 public int getColumnDimension() {
241 return 3;
242 }
243 };
244 final PreconditionedIterativeLinearSolver solver;
245 solver = new ConjugateGradient(10, 0d, false);
246 final ArrayRealVector b = new ArrayRealVector(a.getRowDimension());
247 solver.solve(a, m, b);
248 }
249
250 @Test(expected = MathIllegalArgumentException.class)
251 public void testMismatchedOperatorDimensions() {
252 final Array2DRowRealMatrix a = new Array2DRowRealMatrix(2, 2);
253 final RealLinearOperator m = new RealLinearOperator() {
254
255 @Override
256 public RealVector operate(final RealVector x) {
257 throw new UnsupportedOperationException();
258 }
259
260 @Override
261 public int getRowDimension() {
262 return 3;
263 }
264
265 @Override
266 public int getColumnDimension() {
267 return 3;
268 }
269 };
270 final PreconditionedIterativeLinearSolver solver;
271 solver = new ConjugateGradient(10, 0d, false);
272 final ArrayRealVector b = new ArrayRealVector(a.getRowDimension());
273 solver.solve(a, m, b);
274 }
275
276 @Test(expected = MathIllegalArgumentException.class)
277 public void testNonPositiveDefinitePreconditioner() {
278 final Array2DRowRealMatrix a = new Array2DRowRealMatrix(2, 2);
279 a.setEntry(0, 0, 1d);
280 a.setEntry(0, 1, 2d);
281 a.setEntry(1, 0, 3d);
282 a.setEntry(1, 1, 4d);
283 final RealLinearOperator m = new RealLinearOperator() {
284
285 @Override
286 public RealVector operate(final RealVector x) {
287 final ArrayRealVector y = new ArrayRealVector(2);
288 y.setEntry(0, -x.getEntry(0));
289 y.setEntry(1, x.getEntry(1));
290 return y;
291 }
292
293 @Override
294 public int getRowDimension() {
295 return 2;
296 }
297
298 @Override
299 public int getColumnDimension() {
300 return 2;
301 }
302 };
303 final PreconditionedIterativeLinearSolver solver;
304 solver = new ConjugateGradient(10, 0d, true);
305 final ArrayRealVector b = new ArrayRealVector(2);
306 b.setEntry(0, -1d);
307 b.setEntry(1, -1d);
308 solver.solve(a, m, b);
309 }
310
311 @Test
312 public void testPreconditionedSolution() {
313 final int n = 8;
314 final int maxIterations = 100;
315 final RealLinearOperator a = new HilbertMatrix(n);
316 final InverseHilbertMatrix ainv = new InverseHilbertMatrix(n);
317 final RealLinearOperator m = JacobiPreconditioner.create(a);
318 final PreconditionedIterativeLinearSolver solver;
319 solver = new ConjugateGradient(maxIterations, 1E-15, true);
320 final RealVector b = new ArrayRealVector(n);
321 for (int j = 0; j < n; j++) {
322 b.set(0.);
323 b.setEntry(j, 1.);
324 final RealVector x = solver.solve(a, m, b);
325 for (int i = 0; i < n; i++) {
326 final double actual = x.getEntry(i);
327 final double expected = ainv.getEntry(i, j);
328 final double delta = 1E-6 * FastMath.abs(expected);
329 final String msg = String.format("coefficient (%d, %d)", i, j);
330 Assert.assertEquals(msg, expected, actual, delta);
331 }
332 }
333 }
334
335 @Test
336 public void testPreconditionedResidual() {
337 final int n = 10;
338 final int maxIterations = n;
339 final RealLinearOperator a = new HilbertMatrix(n);
340 final RealLinearOperator m = JacobiPreconditioner.create(a);
341 final ConjugateGradient solver;
342 solver = new ConjugateGradient(maxIterations, 1E-15, true);
343 final RealVector r = new ArrayRealVector(n);
344 final RealVector x = new ArrayRealVector(n);
345 final IterationListener listener = new IterationListener() {
346
347 public void terminationPerformed(final IterationEvent e) {
348
349 }
350
351 public void iterationStarted(final IterationEvent e) {
352
353 }
354
355 public void iterationPerformed(final IterationEvent e) {
356 final IterativeLinearSolverEvent evt;
357 evt = (IterativeLinearSolverEvent) e;
358 RealVector v = evt.getResidual();
359 r.setSubVector(0, v);
360 v = evt.getSolution();
361 x.setSubVector(0, v);
362 }
363
364 public void initializationPerformed(final IterationEvent e) {
365
366 }
367 };
368 solver.getIterationManager().addIterationListener(listener);
369 final RealVector b = new ArrayRealVector(n);
370
371 for (int j = 0; j < n; j++) {
372 b.set(0.);
373 b.setEntry(j, 1.);
374
375 boolean caught = false;
376 try {
377 solver.solve(a, m, b);
378 } catch (MathIllegalStateException e) {
379 caught = true;
380 final RealVector y = a.operate(x);
381 for (int i = 0; i < n; i++) {
382 final double actual = b.getEntry(i) - y.getEntry(i);
383 final double expected = r.getEntry(i);
384 final double delta = 1E-6 * FastMath.abs(expected);
385 final String msg = String.format("column %d, residual %d", i, j);
386 Assert.assertEquals(msg, expected, actual, delta);
387 }
388 }
389 Assert.assertTrue("MathIllegalStateException should have been caught", caught);
390 }
391 }
392
393 @Test
394 public void testPreconditionedSolution2() {
395 final int n = 100;
396 final int maxIterations = 100000;
397 final Array2DRowRealMatrix a = new Array2DRowRealMatrix(n, n);
398 double daux = 1.;
399 for (int i = 0; i < n; i++) {
400 a.setEntry(i, i, daux);
401 daux *= 1.2;
402 for (int j = i + 1; j < n; j++) {
403 if (i == j) {
404 } else {
405 final double value = 1.0;
406 a.setEntry(i, j, value);
407 a.setEntry(j, i, value);
408 }
409 }
410 }
411 final RealLinearOperator m = JacobiPreconditioner.create(a);
412 final PreconditionedIterativeLinearSolver pcg;
413 final IterativeLinearSolver cg;
414 pcg = new ConjugateGradient(maxIterations, 1E-6, true);
415 cg = new ConjugateGradient(maxIterations, 1E-6, true);
416 final RealVector b = new ArrayRealVector(n);
417 final String pattern = "preconditioned gradient (%d iterations) should"
418 + " have been faster than unpreconditioned (%d iterations)";
419 String msg;
420 for (int j = 0; j < 1; j++) {
421 b.set(0.);
422 b.setEntry(j, 1.);
423 final RealVector px = pcg.solve(a, m, b);
424 final RealVector x = cg.solve(a, b);
425 final int npcg = pcg.getIterationManager().getIterations();
426 final int ncg = cg.getIterationManager().getIterations();
427 msg = String.format(pattern, npcg, ncg);
428 Assert.assertTrue(msg, npcg < ncg);
429 for (int i = 0; i < n; i++) {
430 msg = String.format("row %d, column %d", i, j);
431 final double expected = x.getEntry(i);
432 final double actual = px.getEntry(i);
433 final double delta = 1E-6 * FastMath.abs(expected);
434 Assert.assertEquals(msg, expected, actual, delta);
435 }
436 }
437 }
438
439 @Test
440 public void testEventManagement() {
441 final int n = 5;
442 final int maxIterations = 100;
443 final RealLinearOperator a = new HilbertMatrix(n);
444 final IterativeLinearSolver solver;
445
446
447
448
449
450
451 final int[] count = new int[] {0, 0, 0, 0};
452 final IterationListener listener = new IterationListener() {
453 private void doTestVectorsAreUnmodifiable(final IterationEvent e) {
454 final IterativeLinearSolverEvent evt;
455 evt = (IterativeLinearSolverEvent) e;
456 try {
457 evt.getResidual().set(0.0);
458 Assert.fail("r is modifiable");
459 } catch (MathRuntimeException exc){
460
461 }
462 try {
463 evt.getRightHandSideVector().set(0.0);
464 Assert.fail("b is modifiable");
465 } catch (MathRuntimeException exc){
466
467 }
468 try {
469 evt.getSolution().set(0.0);
470 Assert.fail("x is modifiable");
471 } catch (MathRuntimeException exc){
472
473 }
474 }
475
476 public void initializationPerformed(final IterationEvent e) {
477 ++count[0];
478 doTestVectorsAreUnmodifiable(e);
479 }
480
481 public void iterationPerformed(final IterationEvent e) {
482 ++count[2];
483 Assert.assertEquals("iteration performed",
484 count[2], e.getIterations() - 1);
485 doTestVectorsAreUnmodifiable(e);
486 }
487
488 public void iterationStarted(final IterationEvent e) {
489 ++count[1];
490 Assert.assertEquals("iteration started",
491 count[1], e.getIterations() - 1);
492 doTestVectorsAreUnmodifiable(e);
493 }
494
495 public void terminationPerformed(final IterationEvent e) {
496 ++count[3];
497 doTestVectorsAreUnmodifiable(e);
498 }
499 };
500 solver = new ConjugateGradient(maxIterations, 1E-10, true);
501 solver.getIterationManager().addIterationListener(listener);
502 final RealVector b = new ArrayRealVector(n);
503 for (int j = 0; j < n; j++) {
504 Arrays.fill(count, 0);
505 b.set(0.);
506 b.setEntry(j, 1.);
507 solver.solve(a, b);
508 String msg = String.format("column %d (initialization)", j);
509 Assert.assertEquals(msg, 1, count[0]);
510 msg = String.format("column %d (finalization)", j);
511 Assert.assertEquals(msg, 1, count[3]);
512 }
513 }
514
515 @Test
516 public void testUnpreconditionedNormOfResidual() {
517 final int n = 5;
518 final int maxIterations = 100;
519 final RealLinearOperator a = new HilbertMatrix(n);
520 final IterativeLinearSolver solver;
521 final IterationListener listener = new IterationListener() {
522
523 private void doTestNormOfResidual(final IterationEvent e) {
524 final IterativeLinearSolverEvent evt;
525 evt = (IterativeLinearSolverEvent) e;
526 final RealVector x = evt.getSolution();
527 final RealVector b = evt.getRightHandSideVector();
528 final RealVector r = b.subtract(a.operate(x));
529 final double rnorm = r.getNorm();
530 Assert.assertEquals("iteration performed (residual)",
531 rnorm, evt.getNormOfResidual(),
532 FastMath.max(1E-5 * rnorm, 1E-10));
533 }
534
535 public void initializationPerformed(final IterationEvent e) {
536 doTestNormOfResidual(e);
537 }
538
539 public void iterationPerformed(final IterationEvent e) {
540 doTestNormOfResidual(e);
541 }
542
543 public void iterationStarted(final IterationEvent e) {
544 doTestNormOfResidual(e);
545 }
546
547 public void terminationPerformed(final IterationEvent e) {
548 doTestNormOfResidual(e);
549 }
550 };
551 solver = new ConjugateGradient(maxIterations, 1E-10, true);
552 solver.getIterationManager().addIterationListener(listener);
553 final RealVector b = new ArrayRealVector(n);
554 for (int j = 0; j < n; j++) {
555 b.set(0.);
556 b.setEntry(j, 1.);
557 solver.solve(a, b);
558 }
559 }
560
561 @Test
562 public void testPreconditionedNormOfResidual() {
563 final int n = 5;
564 final int maxIterations = 100;
565 final RealLinearOperator a = new HilbertMatrix(n);
566 final RealLinearOperator m = JacobiPreconditioner.create(a);
567 final PreconditionedIterativeLinearSolver solver;
568 final IterationListener listener = new IterationListener() {
569
570 private void doTestNormOfResidual(final IterationEvent e) {
571 final IterativeLinearSolverEvent evt;
572 evt = (IterativeLinearSolverEvent) e;
573 final RealVector x = evt.getSolution();
574 final RealVector b = evt.getRightHandSideVector();
575 final RealVector r = b.subtract(a.operate(x));
576 final double rnorm = r.getNorm();
577 Assert.assertEquals("iteration performed (residual)",
578 rnorm, evt.getNormOfResidual(),
579 FastMath.max(1E-5 * rnorm, 1E-10));
580 }
581
582 public void initializationPerformed(final IterationEvent e) {
583 doTestNormOfResidual(e);
584 }
585
586 public void iterationPerformed(final IterationEvent e) {
587 doTestNormOfResidual(e);
588 }
589
590 public void iterationStarted(final IterationEvent e) {
591 doTestNormOfResidual(e);
592 }
593
594 public void terminationPerformed(final IterationEvent e) {
595 doTestNormOfResidual(e);
596 }
597 };
598 solver = new ConjugateGradient(maxIterations, 1E-10, true);
599 solver.getIterationManager().addIterationListener(listener);
600 final RealVector b = new ArrayRealVector(n);
601 for (int j = 0; j < n; j++) {
602 b.set(0.);
603 b.setEntry(j, 1.);
604 solver.solve(a, m, b);
605 }
606 }
607 }