1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.hipparchus.filtering.kalman.unscented;
18
19 import org.hipparchus.exception.LocalizedCoreFormats;
20 import org.hipparchus.exception.MathIllegalArgumentException;
21 import org.hipparchus.exception.MathRuntimeException;
22 import org.hipparchus.filtering.kalman.KalmanFilter;
23 import org.hipparchus.filtering.kalman.Measurement;
24 import org.hipparchus.filtering.kalman.ProcessEstimate;
25 import org.hipparchus.linear.MatrixDecomposer;
26 import org.hipparchus.linear.MatrixUtils;
27 import org.hipparchus.linear.RealMatrix;
28 import org.hipparchus.linear.RealVector;
29 import org.hipparchus.util.UnscentedTransformProvider;
30
31
32
33
34
35
36
37
38
39
40 public class UnscentedKalmanFilter<T extends Measurement> implements KalmanFilter<T> {
41
42
43 private UnscentedProcess<T> process;
44
45
46 private ProcessEstimate predicted;
47
48
49 private ProcessEstimate corrected;
50
51
52 private final MatrixDecomposer decomposer;
53
54
55 private final int n;
56
57
58 private final UnscentedTransformProvider utProvider;
59
60
61
62
63
64
65
66 public UnscentedKalmanFilter(final MatrixDecomposer decomposer,
67 final UnscentedProcess<T> process,
68 final ProcessEstimate initialState,
69 final UnscentedTransformProvider utProvider) {
70 this.decomposer = decomposer;
71 this.process = process;
72 this.corrected = initialState;
73 this.n = corrected.getState().getDimension();
74 this.utProvider = utProvider;
75
76 if (n == 0) {
77
78 throw new MathIllegalArgumentException(LocalizedCoreFormats.ZERO_STATE_SIZE);
79 }
80 }
81
82
83 @Override
84 public ProcessEstimate estimationStep(final T measurement) throws MathRuntimeException {
85
86
87 final RealVector[] sigmaPoints = utProvider.unscentedTransform(corrected.getState(), corrected.getCovariance());
88
89
90 return predictionAndCorrectionSteps(measurement, sigmaPoints);
91
92 }
93
94
95
96
97
98
99
100 public ProcessEstimate predictionAndCorrectionSteps(final T measurement, final RealVector[] sigmaPoints) throws MathRuntimeException {
101
102
103 final UnscentedEvolution evolution = process.getEvolution(getCorrected().getTime(),
104 sigmaPoints, measurement);
105
106 predict(evolution.getCurrentTime(), evolution.getCurrentStates(),
107 evolution.getProcessNoiseMatrix());
108
109
110 final RealVector[] predictedSigmaPoints = utProvider.unscentedTransform(predicted.getState(),
111 predicted.getCovariance());
112
113
114 final RealVector[] predictedMeasurements = process.getPredictedMeasurements(predictedSigmaPoints, measurement);
115 final RealVector predictedMeasurement = utProvider.getUnscentedMeanState(predictedMeasurements);
116 final RealMatrix r = computeInnovationCovarianceMatrix(predictedMeasurements, predictedMeasurement, measurement.getCovariance());
117 final RealMatrix crossCovarianceMatrix = computeCrossCovarianceMatrix(predictedSigmaPoints, predicted.getState(),
118 predictedMeasurements, predictedMeasurement);
119 final RealVector innovation = (r == null) ? null : process.getInnovation(measurement, predictedMeasurement, predicted.getState(), r);
120 correct(measurement, r, crossCovarianceMatrix, innovation);
121 return getCorrected();
122
123 }
124
125
126
127
128
129
130 private void predict(final double time, final RealVector[] predictedStates, final RealMatrix noise) {
131
132
133 final RealVector predictedState = utProvider.getUnscentedMeanState(predictedStates);
134
135
136 final RealMatrix predictedCovariance = utProvider.getUnscentedCovariance(predictedStates, predictedState).add(noise);
137
138 predicted = new ProcessEstimate(time, predictedState, predictedCovariance);
139 corrected = null;
140
141 }
142
143
144
145
146
147
148
149
150
151
152 private void correct(final T measurement, final RealMatrix innovationCovarianceMatrix,
153 final RealMatrix crossCovarianceMatrix, final RealVector innovation)
154 throws MathIllegalArgumentException {
155
156 if (innovation == null) {
157
158 corrected = predicted;
159 return;
160 }
161
162
163
164
165
166
167
168
169
170 final RealMatrix k = decomposer.
171 decompose(innovationCovarianceMatrix).
172 solve(crossCovarianceMatrix.transpose()).transpose();
173
174
175 final RealVector correctedState = predicted.getState().add(k.operate(innovation));
176
177
178 final RealMatrix correctedCovariance = predicted.getCovariance().
179 subtract(k.multiply(innovationCovarianceMatrix).multiplyTransposed(k));
180
181 corrected = new ProcessEstimate(measurement.getTime(), correctedState, correctedCovariance,
182 null, null, innovationCovarianceMatrix, k);
183
184 }
185
186
187
188 @Override
189 public ProcessEstimate getPredicted() {
190 return predicted;
191 }
192
193
194
195
196 @Override
197 public ProcessEstimate getCorrected() {
198 return corrected;
199 }
200
201
202
203
204 public UnscentedTransformProvider getUnscentedTransformProvider() {
205 return utProvider;
206 }
207
208
209
210
211
212
213
214
215 private RealMatrix computeInnovationCovarianceMatrix(final RealVector[] predictedMeasurements,
216 final RealVector predictedMeasurement,
217 final RealMatrix r) {
218 if (predictedMeasurement == null) {
219 return null;
220 }
221
222 final RealMatrix innovationCovarianceMatrix = utProvider.getUnscentedCovariance(predictedMeasurements, predictedMeasurement);
223
224
225 return innovationCovarianceMatrix.add(r);
226 }
227
228
229
230
231
232
233
234
235
236 private RealMatrix computeCrossCovarianceMatrix(final RealVector[] predictedStates, final RealVector predictedState,
237 final RealVector[] predictedMeasurements, final RealVector predictedMeasurement) {
238
239
240 RealMatrix crossCovarianceMatrix = MatrixUtils.createRealMatrix(predictedState.getDimension(),
241 predictedMeasurement.getDimension());
242
243
244 final RealVector wc = utProvider.getWc();
245
246
247 for (int i = 0; i <= 2 * n; i++) {
248 final RealVector stateDiff = predictedStates[i].subtract(predictedState);
249 final RealVector measDiff = predictedMeasurements[i].subtract(predictedMeasurement);
250 crossCovarianceMatrix = crossCovarianceMatrix.add(outer(stateDiff, measDiff).scalarMultiply(wc.getEntry(i)));
251 }
252
253
254 return crossCovarianceMatrix;
255 }
256
257
258
259
260
261
262 private RealMatrix outer(final RealVector a, final RealVector b) {
263
264
265 final RealMatrix outMatrix = MatrixUtils.createRealMatrix(a.getDimension(), b.getDimension());
266
267
268 for (int row = 0; row < outMatrix.getRowDimension(); row++) {
269 for (int col = 0; col < outMatrix.getColumnDimension(); col++) {
270 outMatrix.setEntry(row, col, a.getEntry(row) * b.getEntry(col));
271 }
272 }
273
274
275 return outMatrix;
276 }
277 }