1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.hipparchus.filtering.kalman.unscented;
18
19 import org.hipparchus.exception.LocalizedCoreFormats;
20 import org.hipparchus.exception.MathIllegalArgumentException;
21 import org.hipparchus.exception.MathRuntimeException;
22 import org.hipparchus.filtering.kalman.KalmanFilter;
23 import org.hipparchus.filtering.kalman.Measurement;
24 import org.hipparchus.filtering.kalman.ProcessEstimate;
25 import org.hipparchus.linear.ArrayRealVector;
26 import org.hipparchus.linear.MatrixDecomposer;
27 import org.hipparchus.linear.MatrixUtils;
28 import org.hipparchus.linear.RealMatrix;
29 import org.hipparchus.linear.RealVector;
30 import org.hipparchus.util.UnscentedTransformProvider;
31
32
33
34
35
36
37
38
39
40
41 public class UnscentedKalmanFilter<T extends Measurement> implements KalmanFilter<T> {
42
43
44 private UnscentedProcess<T> process;
45
46
47 private ProcessEstimate predicted;
48
49
50 private ProcessEstimate corrected;
51
52
53 private final MatrixDecomposer decomposer;
54
55
56 private final int n;
57
58
59 private final UnscentedTransformProvider utProvider;
60
61
62
63
64
65
66
67 public UnscentedKalmanFilter(final MatrixDecomposer decomposer,
68 final UnscentedProcess<T> process,
69 final ProcessEstimate initialState,
70 final UnscentedTransformProvider utProvider) {
71 this.decomposer = decomposer;
72 this.process = process;
73 this.corrected = initialState;
74 this.n = corrected.getState().getDimension();
75 this.utProvider = utProvider;
76
77 if (n == 0) {
78
79 throw new MathIllegalArgumentException(LocalizedCoreFormats.ZERO_STATE_SIZE);
80 }
81 }
82
83
84 @Override
85 public ProcessEstimate estimationStep(final T measurement) throws MathRuntimeException {
86
87
88 final RealVector[] sigmaPoints = utProvider.unscentedTransform(corrected.getState(), corrected.getCovariance());
89
90
91 return predictionAndCorrectionSteps(measurement, sigmaPoints);
92
93 }
94
95
96
97
98
99
100
101 public ProcessEstimate predictionAndCorrectionSteps(final T measurement, final RealVector[] sigmaPoints) throws MathRuntimeException {
102
103
104 final UnscentedEvolution evolution = process.getEvolution(getCorrected().getTime(),
105 sigmaPoints, measurement);
106
107 predict(evolution.getCurrentTime(), evolution.getCurrentStates(),
108 evolution.getProcessNoiseMatrix());
109
110
111 final RealVector[] predictedSigmaPoints = utProvider.unscentedTransform(predicted.getState(),
112 predicted.getCovariance());
113
114
115 final RealVector[] predictedMeasurements = process.getPredictedMeasurements(predictedSigmaPoints, measurement);
116 final RealVector predictedMeasurement = sum(predictedMeasurements, measurement.getValue().getDimension());
117 final RealMatrix r = computeInnovationCovarianceMatrix(predictedMeasurements, predictedMeasurement, measurement.getCovariance());
118 final RealMatrix crossCovarianceMatrix = computeCrossCovarianceMatrix(predictedSigmaPoints, predicted.getState(),
119 predictedMeasurements, predictedMeasurement);
120 final RealVector innovation = (r == null) ? null : process.getInnovation(measurement, predictedMeasurement, predicted.getState(), r);
121 correct(measurement, r, crossCovarianceMatrix, innovation);
122 return getCorrected();
123
124 }
125
126
127
128
129
130
131 private void predict(final double time, final RealVector[] predictedStates, final RealMatrix noise) {
132
133
134 final RealVector predictedState = sum(predictedStates, n);
135
136
137 final RealMatrix predictedCovariance = computeCovariance(predictedStates, predictedState).add(noise);
138
139 predicted = new ProcessEstimate(time, predictedState, predictedCovariance);
140 corrected = null;
141
142 }
143
144
145
146
147
148
149
150
151
152
153 private void correct(final T measurement, final RealMatrix innovationCovarianceMatrix,
154 final RealMatrix crossCovarianceMatrix, final RealVector innovation)
155 throws MathIllegalArgumentException {
156
157 if (innovation == null) {
158
159 corrected = predicted;
160 return;
161 }
162
163
164
165
166
167
168
169
170
171 final RealMatrix k = decomposer.
172 decompose(innovationCovarianceMatrix).
173 solve(crossCovarianceMatrix.transpose()).transpose();
174
175
176 final RealVector correctedState = predicted.getState().add(k.operate(innovation));
177
178
179 final RealMatrix correctedCovariance = predicted.getCovariance().
180 subtract(k.multiply(innovationCovarianceMatrix).multiplyTransposed(k));
181
182 corrected = new ProcessEstimate(measurement.getTime(), correctedState, correctedCovariance,
183 null, null, innovationCovarianceMatrix, k);
184
185 }
186
187
188
189 @Override
190 public ProcessEstimate getPredicted() {
191 return predicted;
192 }
193
194
195
196
197 @Override
198 public ProcessEstimate getCorrected() {
199 return corrected;
200 }
201
202
203
204
205 public UnscentedTransformProvider getUnscentedTransformProvider() {
206 return utProvider;
207 }
208
209
210
211
212
213
214
215
216 private RealMatrix computeInnovationCovarianceMatrix(final RealVector[] predictedMeasurements,
217 final RealVector predictedMeasurement,
218 final RealMatrix r) {
219 if (predictedMeasurement == null) {
220 return null;
221 }
222
223 final RealMatrix innovationCovarianceMatrix = computeCovariance(predictedMeasurements, predictedMeasurement);
224
225 return innovationCovarianceMatrix.add(r);
226 }
227
228
229
230
231
232
233
234
235
236 private RealMatrix computeCrossCovarianceMatrix(final RealVector[] predictedStates, final RealVector predictedState,
237 final RealVector[] predictedMeasurements, final RealVector predictedMeasurement) {
238
239
240 RealMatrix crossCovarianceMatrix = MatrixUtils.createRealMatrix(predictedState.getDimension(),
241 predictedMeasurement.getDimension());
242
243
244 final RealVector wc = utProvider.getWc();
245
246
247 for (int i = 0; i <= 2 * n; i++) {
248 final RealVector stateDiff = predictedStates[i].subtract(predictedState);
249 final RealVector measDiff = predictedMeasurements[i].subtract(predictedMeasurement);
250 crossCovarianceMatrix = crossCovarianceMatrix.add(outer(stateDiff, measDiff).scalarMultiply(wc.getEntry(i)));
251 }
252
253
254 return crossCovarianceMatrix;
255 }
256
257
258
259
260
261
262
263
264
265
266
267
268
269 private RealVector sum(final RealVector[] samples, final int size) {
270
271
272 RealVector mean = new ArrayRealVector(size);
273
274
275 final RealVector wm = utProvider.getWm();
276
277
278 for (int i = 0; i <= 2 * n; i++) {
279 mean = mean.add(samples[i].mapMultiply(wm.getEntry(i)));
280 }
281
282
283 return mean;
284
285 }
286
287
288
289
290
291
292
293
294
295
296
297
298
299 private RealMatrix computeCovariance(final RealVector[] samples,
300 final RealVector state) {
301
302
303 final int dim = state.getDimension();
304 RealMatrix covarianceMatrix = MatrixUtils.createRealMatrix(dim, dim);
305
306
307 final RealVector wc = utProvider.getWc();
308
309
310 for (int i = 0; i <= 2 * n; i++) {
311 final RealVector diff = samples[i].subtract(state);
312 covarianceMatrix = covarianceMatrix.add(outer(diff, diff).scalarMultiply(wc.getEntry(i)));
313 }
314
315
316 return covarianceMatrix;
317
318 }
319
320
321
322
323
324
325 private RealMatrix outer(final RealVector a, final RealVector b) {
326
327
328 final RealMatrix outMatrix = MatrixUtils.createRealMatrix(a.getDimension(), b.getDimension());
329
330
331 for (int row = 0; row < outMatrix.getRowDimension(); row++) {
332 for (int col = 0; col < outMatrix.getColumnDimension(); col++) {
333 outMatrix.setEntry(row, col, a.getEntry(row) * b.getEntry(col));
334 }
335 }
336
337
338 return outMatrix;
339
340 }
341
342 }