View Javadoc
1   /*
2    * Licensed to the Hipparchus project under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The Hipparchus project licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.hipparchus.filtering.kalman.unscented;
18  
19  import org.hipparchus.exception.LocalizedCoreFormats;
20  import org.hipparchus.exception.MathIllegalArgumentException;
21  import org.hipparchus.exception.MathRuntimeException;
22  import org.hipparchus.filtering.kalman.KalmanFilter;
23  import org.hipparchus.filtering.kalman.Measurement;
24  import org.hipparchus.filtering.kalman.ProcessEstimate;
25  import org.hipparchus.linear.ArrayRealVector;
26  import org.hipparchus.linear.MatrixDecomposer;
27  import org.hipparchus.linear.MatrixUtils;
28  import org.hipparchus.linear.RealMatrix;
29  import org.hipparchus.linear.RealVector;
30  import org.hipparchus.util.UnscentedTransformProvider;
31  
32  /**
33   * Unscented Kalman filter for {@link UnscentedProcess unscented process}.
34   * @param <T> the type of the measurements
35   *
36   * @see "Wan, E. A., & Van Der Merwe, R. (2000, October). The unscented Kalman filter for nonlinear estimation.
37   *       In Proceedings of the IEEE 2000 Adaptive Systems for Signal Processing, Communications, and Control Symposium
38   *       (Cat. No. 00EX373) (pp. 153-158)"
39   * @since 2.2
40   */
41  public class UnscentedKalmanFilter<T extends Measurement> implements KalmanFilter<T> {
42  
43      /** Process to be estimated. */
44      private UnscentedProcess<T> process;
45  
46      /** Predicted state. */
47      private ProcessEstimate predicted;
48  
49      /** Corrected state. */
50      private ProcessEstimate corrected;
51  
52      /** Decompose to use for the correction phase. */
53      private final MatrixDecomposer decomposer;
54  
55      /** Number of estimated parameters. */
56      private final int n;
57  
58      /** Unscented transform provider. */
59      private final UnscentedTransformProvider utProvider;
60  
61      /** Simple constructor.
62       * @param decomposer decomposer to use for the correction phase
63       * @param process unscented process to estimate
64       * @param initialState initial state
65       * @param utProvider unscented transform provider
66       */
67      public UnscentedKalmanFilter(final MatrixDecomposer decomposer,
68                                   final UnscentedProcess<T> process,
69                                   final ProcessEstimate initialState,
70                                   final UnscentedTransformProvider utProvider) {
71          this.decomposer = decomposer;
72          this.process    = process;
73          this.corrected  = initialState;
74          this.n          = corrected.getState().getDimension();
75          this.utProvider = utProvider;
76          // Check state dimension
77          if (n == 0) {
78              // State dimension must be different from 0
79              throw new MathIllegalArgumentException(LocalizedCoreFormats.ZERO_STATE_SIZE);
80          }
81      }
82  
83      /** {@inheritDoc} */
84      @Override
85      public ProcessEstimate estimationStep(final T measurement) throws MathRuntimeException {
86  
87          // Calculate sigma points
88          final RealVector[] sigmaPoints = utProvider.unscentedTransform(corrected.getState(), corrected.getCovariance());
89  
90          // Perform the prediction and correction steps
91          return predictionAndCorrectionSteps(measurement, sigmaPoints);
92  
93      }
94  
95      /** This method perform the prediction and correction steps of the Unscented Kalman Filter.
96       * @param measurement single measurement to handle
97       * @param sigmaPoints computed sigma points
98       * @return estimated state after measurement has been considered
99       * @throws MathRuntimeException if matrix cannot be decomposed
100      */
101     public ProcessEstimate predictionAndCorrectionSteps(final T measurement, final RealVector[] sigmaPoints) throws MathRuntimeException {
102 
103         // Prediction phase
104         final UnscentedEvolution evolution = process.getEvolution(getCorrected().getTime(),
105                                                                   sigmaPoints, measurement);
106 
107         predict(evolution.getCurrentTime(), evolution.getCurrentStates(),
108                 evolution.getProcessNoiseMatrix());
109 
110         // Calculate sigma points from predicted state
111         final RealVector[] predictedSigmaPoints = utProvider.unscentedTransform(predicted.getState(),
112                                                                                 predicted.getCovariance());
113 
114         // Correction phase
115         final RealVector[] predictedMeasurements = process.getPredictedMeasurements(predictedSigmaPoints, measurement);
116         final RealVector   predictedMeasurement  = sum(predictedMeasurements, measurement.getValue().getDimension());
117         final RealMatrix   r                     = computeInnovationCovarianceMatrix(predictedMeasurements, predictedMeasurement, measurement.getCovariance());
118         final RealMatrix   crossCovarianceMatrix = computeCrossCovarianceMatrix(predictedSigmaPoints, predicted.getState(),
119                                                                                 predictedMeasurements, predictedMeasurement);
120         final RealVector   innovation            = (r == null) ? null : process.getInnovation(measurement, predictedMeasurement, predicted.getState(), r);
121         correct(measurement, r, crossCovarianceMatrix, innovation);
122         return getCorrected();
123 
124     }
125 
126     /** Perform prediction step.
127      * @param time process time
128      * @param predictedStates predicted state vectors
129      * @param noise process noise covariance matrix
130      */
131     private void predict(final double time, final RealVector[] predictedStates,  final RealMatrix noise) {
132 
133         // Computation of Eq. 17, weighted mean state
134         final RealVector predictedState = sum(predictedStates, n);
135 
136         // Computation of Eq. 18, predicted covariance matrix
137         final RealMatrix predictedCovariance = computeCovariance(predictedStates, predictedState).add(noise);
138 
139         predicted = new ProcessEstimate(time, predictedState, predictedCovariance);
140         corrected = null;
141 
142     }
143 
144     /** Perform correction step.
145      * @param measurement single measurement to handle
146      * @param innovationCovarianceMatrix innovation covariance matrix
147      * (may be null if measurement should be ignored)
148      * @param crossCovarianceMatrix cross covariance matrix
149      * @param innovation innovation
150      * (may be null if measurement should be ignored)
151      * @exception MathIllegalArgumentException if matrix cannot be decomposed
152      */
153     private void correct(final T measurement, final RealMatrix innovationCovarianceMatrix,
154                            final RealMatrix crossCovarianceMatrix, final RealVector innovation)
155         throws MathIllegalArgumentException {
156 
157         if (innovation == null) {
158             // measurement should be ignored
159             corrected = predicted;
160             return;
161         }
162 
163         // compute Kalman gain k
164         // the following is equivalent to k = P_cross * (R_pred)^-1
165         // we don't want to compute the inverse of a matrix,
166         // we start by post-multiplying by R_pred and get
167         // k.(R_pred) = P_cross
168         // then we transpose, knowing that R_pred is a symmetric matrix
169         // (R_pred).k^T = P_cross^T
170         // then we can use linear system solving instead of matrix inversion
171         final RealMatrix k = decomposer.
172                              decompose(innovationCovarianceMatrix).
173                              solve(crossCovarianceMatrix.transpose()).transpose();
174 
175         // correct state vector
176         final RealVector correctedState = predicted.getState().add(k.operate(innovation));
177 
178         // correct covariance matrix
179         final RealMatrix correctedCovariance = predicted.getCovariance().
180                                                subtract(k.multiply(innovationCovarianceMatrix).multiplyTransposed(k));
181 
182         corrected = new ProcessEstimate(measurement.getTime(), correctedState, correctedCovariance,
183                                         null, null, innovationCovarianceMatrix, k);
184 
185     }
186     /** Get the predicted state.
187      * @return predicted state
188      */
189     @Override
190     public ProcessEstimate getPredicted() {
191         return predicted;
192     }
193 
194     /** Get the corrected state.
195      * @return corrected state
196      */
197     @Override
198     public ProcessEstimate getCorrected() {
199         return corrected;
200     }
201 
202     /** Get the unscented transform provider.
203      * @return unscented transform provider
204      */
205     public UnscentedTransformProvider getUnscentedTransformProvider() {
206         return utProvider;
207     }
208 
209     /** Computes innovation covariance matrix.
210      * @param predictedMeasurements predicted measurements (one per sigma point)
211      * @param predictedMeasurement predicted measurements
212      *        (may be null if measurement should be ignored)
213      * @param r measurement covariance
214      * @return innovation covariance matrix (null if predictedMeasurement is null)
215      */
216     private RealMatrix computeInnovationCovarianceMatrix(final RealVector[] predictedMeasurements,
217                                                          final RealVector predictedMeasurement,
218                                                          final RealMatrix r) {
219         if (predictedMeasurement == null) {
220             return null;
221         }
222         // Computation of the innovation covariance matrix
223         final RealMatrix innovationCovarianceMatrix = computeCovariance(predictedMeasurements, predictedMeasurement);
224         // Add the measurement covariance
225         return innovationCovarianceMatrix.add(r);
226     }
227 
228     /**
229      * Computes cross covariance matrix.
230      * @param predictedStates predicted states
231      * @param predictedState predicted state
232      * @param predictedMeasurements current measurements
233      * @param predictedMeasurement predicted measurements
234      * @return cross covariance matrix
235      */
236     private RealMatrix computeCrossCovarianceMatrix(final RealVector[] predictedStates, final RealVector predictedState,
237                                                     final RealVector[] predictedMeasurements, final RealVector predictedMeasurement) {
238 
239         // Initialize the cross covariance matrix
240         RealMatrix crossCovarianceMatrix = MatrixUtils.createRealMatrix(predictedState.getDimension(),
241                                                                         predictedMeasurement.getDimension());
242 
243         // Covariance weights
244         final RealVector wc = utProvider.getWc();
245 
246         // Compute the cross covariance matrix
247         for (int i = 0; i <= 2 * n; i++) {
248             final RealVector stateDiff = predictedStates[i].subtract(predictedState);
249             final RealVector measDiff  = predictedMeasurements[i].subtract(predictedMeasurement);
250             crossCovarianceMatrix = crossCovarianceMatrix.add(outer(stateDiff, measDiff).scalarMultiply(wc.getEntry(i)));
251         }
252 
253         // Return the cross covariance
254         return crossCovarianceMatrix;
255     }
256 
257     /**
258      * Computes a weighted mean parameter from a given samples.
259      * <p>
260      * This method can be used for computing both the mean state and the mean measurement.
261      * <p>
262      * It corresponds to the Equation 17 of "Wan, E. A., & Van Der Merwe, R.
263      * The unscented Kalman filter for nonlinear estimation"
264      * </p>
265      * @param samples input samples
266      * @param size size of the weighted mean parameter
267      * @return weighted mean parameter
268      */
269     private RealVector sum(final RealVector[] samples, final int size) {
270 
271         // Initialize the weighted mean parameter
272         RealVector mean = new ArrayRealVector(size);
273 
274         // Mean weights
275         final RealVector wm = utProvider.getWm();
276 
277         // Compute weighted mean parameter
278         for (int i = 0; i <= 2 * n; i++) {
279             mean = mean.add(samples[i].mapMultiply(wm.getEntry(i)));
280         }
281 
282         // Return the weighted mean value
283         return mean;
284 
285     }
286 
287     /** Computes the covariance matrix.
288      * <p>
289      * This method can be used for computing both the predicted state
290      * covariance matrix and the innovation covariance matrix.
291      * <p>
292      * It corresponds to the Equation 18 of "Wan, E. A., & Van Der Merwe, R.
293      * The unscented Kalman filter for nonlinear estimation"
294      * </p>
295      * @param samples input samples
296      * @param state weighted mean parameter
297      * @return the covariance matrix
298      */
299     private RealMatrix computeCovariance(final RealVector[] samples,
300                                          final RealVector state) {
301 
302         // Initialize the covariance matrix, by using the size of the weighted mean parameter
303         final int dim = state.getDimension();
304         RealMatrix covarianceMatrix = MatrixUtils.createRealMatrix(dim, dim);
305 
306         // Covariance weights
307         final RealVector wc = utProvider.getWc();
308 
309         // Compute the covariance matrix
310         for (int i = 0; i <= 2 * n; i++) {
311             final RealVector diff = samples[i].subtract(state);
312             covarianceMatrix = covarianceMatrix.add(outer(diff, diff).scalarMultiply(wc.getEntry(i)));
313         }
314 
315         // Return the covariance
316         return covarianceMatrix;
317 
318     }
319 
320     /** Conputes the outer product of two vectors.
321      * @param a first vector
322      * @param b second vector
323      * @return the outer product of a and b
324      */
325     private RealMatrix outer(final RealVector a, final RealVector b) {
326 
327         // Initialize matrix
328         final RealMatrix outMatrix = MatrixUtils.createRealMatrix(a.getDimension(), b.getDimension());
329 
330         // Fill matrix
331         for (int row = 0; row < outMatrix.getRowDimension(); row++) {
332             for (int col = 0; col < outMatrix.getColumnDimension(); col++) {
333                 outMatrix.setEntry(row, col, a.getEntry(row) * b.getEntry(col));
334             }
335         }
336 
337         // Return
338         return outMatrix;
339 
340     }
341 
342 }