1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 package org.hipparchus.optim.nonlinear.vector.leastsquares;
23
24 import org.hipparchus.analysis.MultivariateMatrixFunction;
25 import org.hipparchus.analysis.MultivariateVectorFunction;
26 import org.hipparchus.exception.MathIllegalStateException;
27 import org.hipparchus.linear.Array2DRowRealMatrix;
28 import org.hipparchus.linear.ArrayRealVector;
29 import org.hipparchus.linear.DiagonalMatrix;
30 import org.hipparchus.linear.EigenDecompositionSymmetric;
31 import org.hipparchus.linear.RealMatrix;
32 import org.hipparchus.linear.RealVector;
33 import org.hipparchus.optim.AbstractOptimizationProblem;
34 import org.hipparchus.optim.ConvergenceChecker;
35 import org.hipparchus.optim.LocalizedOptimFormats;
36 import org.hipparchus.optim.PointVectorValuePair;
37 import org.hipparchus.optim.nonlinear.vector.leastsquares.LeastSquaresProblem.Evaluation;
38 import org.hipparchus.util.FastMath;
39 import org.hipparchus.util.Incrementor;
40 import org.hipparchus.util.Pair;
41
42
43
44
45
46 public class LeastSquaresFactory {
47
48
/** Hidden constructor: this class only exposes static factory methods. */
private LeastSquaresFactory() {
    // nothing to initialize
}
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68 public static LeastSquaresProblem create(final MultivariateJacobianFunction model,
69 final RealVector observed,
70 final RealVector start,
71 final RealMatrix weight,
72 final ConvergenceChecker<Evaluation> checker,
73 final int maxEvaluations,
74 final int maxIterations,
75 final boolean lazyEvaluation,
76 final ParameterValidator paramValidator) {
77 final LeastSquaresProblem p = new LocalLeastSquaresProblem(model,
78 observed,
79 start,
80 checker,
81 maxEvaluations,
82 maxIterations,
83 lazyEvaluation,
84 paramValidator);
85 if (weight != null) {
86 return weightMatrix(p, weight);
87 } else {
88 return p;
89 }
90 }
91
92
93
94
95
96
97
98
99
100
101
102
103
104 public static LeastSquaresProblem create(final MultivariateJacobianFunction model,
105 final RealVector observed,
106 final RealVector start,
107 final ConvergenceChecker<Evaluation> checker,
108 final int maxEvaluations,
109 final int maxIterations) {
110 return create(model,
111 observed,
112 start,
113 null,
114 checker,
115 maxEvaluations,
116 maxIterations,
117 false,
118 null);
119 }
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134 public static LeastSquaresProblem create(final MultivariateJacobianFunction model,
135 final RealVector observed,
136 final RealVector start,
137 final RealMatrix weight,
138 final ConvergenceChecker<Evaluation> checker,
139 final int maxEvaluations,
140 final int maxIterations) {
141 return weightMatrix(create(model,
142 observed,
143 start,
144 checker,
145 maxEvaluations,
146 maxIterations),
147 weight);
148 }
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169 public static LeastSquaresProblem create(final MultivariateVectorFunction model,
170 final MultivariateMatrixFunction jacobian,
171 final double[] observed,
172 final double[] start,
173 final RealMatrix weight,
174 final ConvergenceChecker<Evaluation> checker,
175 final int maxEvaluations,
176 final int maxIterations) {
177 return create(model(model, jacobian),
178 new ArrayRealVector(observed, false),
179 new ArrayRealVector(start, false),
180 weight,
181 checker,
182 maxEvaluations,
183 maxIterations);
184 }
185
186
187
188
189
190
191
192
193
194 public static LeastSquaresProblem weightMatrix(final LeastSquaresProblem problem,
195 final RealMatrix weights) {
196 final RealMatrix weightSquareRoot = squareRoot(weights);
197 return new LeastSquaresAdapter(problem) {
198
199 @Override
200 public Evaluation evaluate(final RealVector point) {
201 return new DenseWeightedEvaluation(super.evaluate(point), weightSquareRoot);
202 }
203 };
204 }
205
206
207
208
209
210
211
212
213
214 public static LeastSquaresProblem weightDiagonal(final LeastSquaresProblem problem,
215 final RealVector weights) {
216
217 return weightMatrix(problem, new DiagonalMatrix(weights.toArray()));
218 }
219
220
221
222
223
224
225
226
227
228
229 public static LeastSquaresProblem countEvaluations(final LeastSquaresProblem problem,
230 final Incrementor counter) {
231 return new LeastSquaresAdapter(problem) {
232
233
234 @Override
235 public Evaluation evaluate(final RealVector point) {
236 counter.increment();
237 return super.evaluate(point);
238 }
239
240
241 };
242 }
243
244
245
246
247
248
249
250
251 public static ConvergenceChecker<Evaluation> evaluationChecker(final ConvergenceChecker<PointVectorValuePair> checker) {
252 return new ConvergenceChecker<Evaluation>() {
253
254 @Override
255 public boolean converged(final int iteration,
256 final Evaluation previous,
257 final Evaluation current) {
258 return checker.converged(
259 iteration,
260 new PointVectorValuePair(
261 previous.getPoint().toArray(),
262 previous.getResiduals().toArray(),
263 false),
264 new PointVectorValuePair(
265 current.getPoint().toArray(),
266 current.getResiduals().toArray(),
267 false)
268 );
269 }
270 };
271 }
272
273
274
275
276
277
278
279 private static RealMatrix squareRoot(final RealMatrix m) {
280 if (m instanceof DiagonalMatrix) {
281 final int dim = m.getRowDimension();
282 final RealMatrix sqrtM = new DiagonalMatrix(dim);
283 for (int i = 0; i < dim; i++) {
284 sqrtM.setEntry(i, i, FastMath.sqrt(m.getEntry(i, i)));
285 }
286 return sqrtM;
287 } else {
288 final EigenDecompositionSymmetric dec = new EigenDecompositionSymmetric(m);
289 return dec.getSquareRoot();
290 }
291 }
292
293
294
295
296
297
298
299
300
301 public static MultivariateJacobianFunction model(final MultivariateVectorFunction value,
302 final MultivariateMatrixFunction jacobian) {
303 return new LocalValueAndJacobianFunction(value, jacobian);
304 }
305
306
307
308
309
310 private static class LocalValueAndJacobianFunction
311 implements ValueAndJacobianFunction {
312
313 private final MultivariateVectorFunction value;
314
315 private final MultivariateMatrixFunction jacobian;
316
317
318
319
320
321 LocalValueAndJacobianFunction(final MultivariateVectorFunction value,
322 final MultivariateMatrixFunction jacobian) {
323 this.value = value;
324 this.jacobian = jacobian;
325 }
326
327
328 @Override
329 public Pair<RealVector, RealMatrix> value(final RealVector point) {
330
331 final double[] p = point.toArray();
332
333
334 return new Pair<RealVector, RealMatrix>(computeValue(p),
335 computeJacobian(p));
336 }
337
338
339 @Override
340 public RealVector computeValue(final double[] params) {
341 return new ArrayRealVector(value.value(params), false);
342 }
343
344
345 @Override
346 public RealMatrix computeJacobian(final double[] params) {
347 return new Array2DRowRealMatrix(jacobian.value(params), false);
348 }
349 }
350
351
352
353
354
355
356 private static class LocalLeastSquaresProblem
357 extends AbstractOptimizationProblem<Evaluation>
358 implements LeastSquaresProblem {
359
360
361 private final RealVector target;
362
363 private final MultivariateJacobianFunction model;
364
365 private final RealVector start;
366
367 private final boolean lazyEvaluation;
368
369 private final ParameterValidator paramValidator;
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384 LocalLeastSquaresProblem(final MultivariateJacobianFunction model,
385 final RealVector target,
386 final RealVector start,
387 final ConvergenceChecker<Evaluation> checker,
388 final int maxEvaluations,
389 final int maxIterations,
390 final boolean lazyEvaluation,
391 final ParameterValidator paramValidator) {
392 super(maxEvaluations, maxIterations, checker);
393 this.target = target;
394 this.model = model;
395 this.start = start;
396 this.lazyEvaluation = lazyEvaluation;
397 this.paramValidator = paramValidator;
398
399 if (lazyEvaluation &&
400 !(model instanceof ValueAndJacobianFunction)) {
401
402
403 throw new MathIllegalStateException(LocalizedOptimFormats.INVALID_IMPLEMENTATION,
404 model.getClass().getName());
405 }
406 }
407
408
409 @Override
410 public int getObservationSize() {
411 return target.getDimension();
412 }
413
414
415 @Override
416 public int getParameterSize() {
417 return start.getDimension();
418 }
419
420
421 @Override
422 public RealVector getStart() {
423 return start == null ? null : start.copy();
424 }
425
426
427 @Override
428 public Evaluation evaluate(final RealVector point) {
429
430 final RealVector p = paramValidator == null ?
431 point.copy() :
432 paramValidator.validate(point.copy());
433
434 if (lazyEvaluation) {
435 return new LazyUnweightedEvaluation((ValueAndJacobianFunction) model,
436 target,
437 p);
438 } else {
439
440 final Pair<RealVector, RealMatrix> value = model.value(p);
441 return new UnweightedEvaluation(value.getFirst(),
442 value.getSecond(),
443 target,
444 p);
445 }
446 }
447
448
449
450
451 private static class UnweightedEvaluation extends AbstractEvaluation {
452
453 private final RealVector point;
454
455 private final RealMatrix jacobian;
456
457 private final RealVector residuals;
458
459
460
461
462
463
464
465
466
467 private UnweightedEvaluation(final RealVector values,
468 final RealMatrix jacobian,
469 final RealVector target,
470 final RealVector point) {
471 super(target.getDimension());
472 this.jacobian = jacobian;
473 this.point = point;
474 this.residuals = target.subtract(values);
475 }
476
477
478 @Override
479 public RealMatrix getJacobian() {
480 return jacobian;
481 }
482
483
484 @Override
485 public RealVector getPoint() {
486 return point;
487 }
488
489
490 @Override
491 public RealVector getResiduals() {
492 return residuals;
493 }
494 }
495
496
497
498
499 private static class LazyUnweightedEvaluation extends AbstractEvaluation {
500
501 private final RealVector point;
502
503 private final ValueAndJacobianFunction model;
504
505 private final RealVector target;
506
507
508
509
510
511
512
513
514 private LazyUnweightedEvaluation(final ValueAndJacobianFunction model,
515 final RealVector target,
516 final RealVector point) {
517 super(target.getDimension());
518
519 this.model = model;
520 this.point = point;
521 this.target = target;
522 }
523
524
525 @Override
526 public RealMatrix getJacobian() {
527 return model.computeJacobian(point.toArray());
528 }
529
530
531 @Override
532 public RealVector getPoint() {
533 return point;
534 }
535
536
537 @Override
538 public RealVector getResiduals() {
539 return target.subtract(model.computeValue(point.toArray()));
540 }
541 }
542 }
543 }
544