UnscentedKalmanFilter.java

  1. /*
  2.  * Licensed to the Hipparchus project under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The Hipparchus project licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      https://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.hipparchus.filtering.kalman.unscented;

  18. import org.hipparchus.exception.LocalizedCoreFormats;
  19. import org.hipparchus.exception.MathIllegalArgumentException;
  20. import org.hipparchus.exception.MathRuntimeException;
  21. import org.hipparchus.filtering.kalman.KalmanFilter;
  22. import org.hipparchus.filtering.kalman.KalmanObserver;
  23. import org.hipparchus.filtering.kalman.Measurement;
  24. import org.hipparchus.filtering.kalman.ProcessEstimate;
  25. import org.hipparchus.linear.MatrixDecomposer;
  26. import org.hipparchus.linear.MatrixUtils;
  27. import org.hipparchus.linear.RealMatrix;
  28. import org.hipparchus.linear.RealVector;
  29. import org.hipparchus.util.UnscentedTransformProvider;

  30. /**
  31.  * Unscented Kalman filter for {@link UnscentedProcess unscented process}.
  32.  * @param <T> the type of the measurements
  33.  *
  34.  * @see "Wan, E. A., & Van Der Merwe, R. (2000, October). The unscented Kalman filter for nonlinear estimation.
  35.  *       In Proceedings of the IEEE 2000 Adaptive Systems for Signal Processing, Communications, and Control Symposium
  36.  *       (Cat. No. 00EX373) (pp. 153-158)"
  37.  * @since 2.2
  38.  */
  39. public class UnscentedKalmanFilter<T extends Measurement> implements KalmanFilter<T> {

  40.     /** Process to be estimated. */
  41.     private final UnscentedProcess<T> process;

  42.     /** Predicted state. */
  43.     private ProcessEstimate predicted;

  44.     /** Corrected state. */
  45.     private ProcessEstimate corrected;

  46.     /** Decompose to use for the correction phase. */
  47.     private final MatrixDecomposer decomposer;

  48.     /** Number of estimated parameters. */
  49.     private final int n;

  50.     /** Unscented transform provider. */
  51.     private final UnscentedTransformProvider utProvider;

  52.     /** Prior corrected sigma-points. */
  53.     private RealVector[] priorSigmaPoints;

  54.     /** Predicted sigma-points. */
  55.     private RealVector[] predictedNoNoiseSigmaPoints;

  56.     /** Observer. */
  57.     private KalmanObserver observer;

  58.     /** Simple constructor.
  59.      * @param decomposer decomposer to use for the correction phase
  60.      * @param process unscented process to estimate
  61.      * @param initialState initial state
  62.      * @param utProvider unscented transform provider
  63.      */
  64.     public UnscentedKalmanFilter(final MatrixDecomposer decomposer,
  65.                                  final UnscentedProcess<T> process,
  66.                                  final ProcessEstimate initialState,
  67.                                  final UnscentedTransformProvider utProvider) {
  68.         this.decomposer = decomposer;
  69.         this.process    = process;
  70.         this.corrected  = initialState;
  71.         this.n          = corrected.getState().getDimension();
  72.         this.utProvider = utProvider;
  73.         this.priorSigmaPoints = null;
  74.         this.predictedNoNoiseSigmaPoints = null;
  75.         this.observer = null;

  76.         // Check state dimension
  77.         if (n == 0) {
  78.             // State dimension must be different from 0
  79.             throw new MathIllegalArgumentException(LocalizedCoreFormats.ZERO_STATE_SIZE);
  80.         }
  81.     }

  82.     /** {@inheritDoc} */
  83.     @Override
  84.     public ProcessEstimate estimationStep(final T measurement) throws MathRuntimeException {

  85.         // Calculate sigma points
  86.         final RealVector[] sigmaPoints = utProvider.unscentedTransform(corrected.getState(), corrected.getCovariance());
  87.         priorSigmaPoints = sigmaPoints;

  88.         // Perform the prediction and correction steps
  89.         return predictionAndCorrectionSteps(measurement, sigmaPoints);

  90.     }

  91.     /** This method perform the prediction and correction steps of the Unscented Kalman Filter.
  92.      * @param measurement single measurement to handle
  93.      * @param sigmaPoints computed sigma points
  94.      * @return estimated state after measurement has been considered
  95.      * @throws MathRuntimeException if matrix cannot be decomposed
  96.      */
  97.     private ProcessEstimate predictionAndCorrectionSteps(final T measurement, final RealVector[] sigmaPoints) throws MathRuntimeException {

  98.         // Prediction phase
  99.         final UnscentedEvolution evolution = process.getEvolution(getCorrected().getTime(),
  100.                                                                   sigmaPoints, measurement);
  101.         predictedNoNoiseSigmaPoints = evolution.getCurrentStates();

  102.         // Computation of Eq. 17, weighted mean state
  103.         final RealVector predictedState = utProvider.getUnscentedMeanState(evolution.getCurrentStates());

  104.         // Calculate process noise
  105.         final RealMatrix processNoiseMatrix = process.getProcessNoiseMatrix(getCorrected().getTime(), predictedState,
  106.                                                                             measurement);

  107.         predict(evolution.getCurrentTime(), evolution.getCurrentStates(), processNoiseMatrix);

  108.         // Calculate sigma points from predicted state
  109.         final RealVector[] predictedSigmaPoints = utProvider.unscentedTransform(predicted.getState(),
  110.                                                                                 predicted.getCovariance());

  111.         // Correction phase
  112.         final RealVector[] predictedMeasurements = process.getPredictedMeasurements(predictedSigmaPoints, measurement);
  113.         final RealVector   predictedMeasurement  = utProvider.getUnscentedMeanState(predictedMeasurements);
  114.         final RealMatrix   r                     = computeInnovationCovarianceMatrix(predictedMeasurements, predictedMeasurement, measurement.getCovariance());
  115.         final RealMatrix   crossCovarianceMatrix = computeCrossCovarianceMatrix(predictedSigmaPoints, predicted.getState(),
  116.                                                                                 predictedMeasurements, predictedMeasurement);
  117.         final RealVector   innovation            = (r == null) ? null : process.getInnovation(measurement, predictedMeasurement, predicted.getState(), r);
  118.         correct(measurement, r, crossCovarianceMatrix, innovation);

  119.         if (observer != null) {
  120.             observer.updatePerformed(this);
  121.         }
  122.         return getCorrected();

  123.     }

  124.     /** Perform prediction step.
  125.      * @param time process time
  126.      * @param predictedStates predicted state vectors
  127.      * @param noise process noise covariance matrix
  128.      */
  129.     private void predict(final double time, final RealVector[] predictedStates, final RealMatrix noise) {

  130.         // Computation of Eq. 17, weighted mean state
  131.         final RealVector predictedState = utProvider.getUnscentedMeanState(predictedStates);

  132.         // Computation of Eq. 18, predicted covariance matrix
  133.         final RealMatrix predictedCovariance = utProvider.getUnscentedCovariance(predictedStates, predictedState).add(noise);

  134.         predicted = new ProcessEstimate(time, predictedState, predictedCovariance);
  135.         corrected = null;

  136.     }

  137.     /** Perform correction step.
  138.      * @param measurement single measurement to handle
  139.      * @param innovationCovarianceMatrix innovation covariance matrix
  140.      * (may be null if measurement should be ignored)
  141.      * @param crossCovarianceMatrix cross covariance matrix
  142.      * @param innovation innovation
  143.      * (may be null if measurement should be ignored)
  144.      * @exception MathIllegalArgumentException if matrix cannot be decomposed
  145.      */
  146.     private void correct(final T measurement, final RealMatrix innovationCovarianceMatrix,
  147.                            final RealMatrix crossCovarianceMatrix, final RealVector innovation)
  148.         throws MathIllegalArgumentException {

  149.         if (innovation == null) {
  150.             // measurement should be ignored
  151.             corrected = predicted;
  152.             return;
  153.         }

  154.         // compute Kalman gain k
  155.         // the following is equivalent to k = P_cross * (R_pred)^-1
  156.         // we don't want to compute the inverse of a matrix,
  157.         // we start by post-multiplying by R_pred and get
  158.         // k.(R_pred) = P_cross
  159.         // then we transpose, knowing that R_pred is a symmetric matrix
  160.         // (R_pred).k^T = P_cross^T
  161.         // then we can use linear system solving instead of matrix inversion
  162.         final RealMatrix k = decomposer.
  163.                              decompose(innovationCovarianceMatrix).
  164.                              solve(crossCovarianceMatrix.transpose()).transpose();

  165.         // correct state vector
  166.         final RealVector correctedState = predicted.getState().add(k.operate(innovation));

  167.         // correct covariance matrix
  168.         final RealMatrix correctedCovariance = predicted.getCovariance().
  169.                                                subtract(k.multiply(innovationCovarianceMatrix).multiplyTransposed(k));

  170.         corrected = new ProcessEstimate(measurement.getTime(), correctedState, correctedCovariance,
  171.                                         null, null, innovationCovarianceMatrix, k);

  172.     }

  173.     /** {@inheritDoc} */
  174.     @Override
  175.     public void setObserver(final KalmanObserver kalmanObserver) {
  176.         observer = kalmanObserver;
  177.         observer.init(this);
  178.     }

  179.     /** Get the predicted state.
  180.      * @return predicted state
  181.      */
  182.     @Override
  183.     public ProcessEstimate getPredicted() {
  184.         return predicted;
  185.     }

  186.     /** Get the corrected state.
  187.      * @return corrected state
  188.      */
  189.     @Override
  190.     public ProcessEstimate getCorrected() {
  191.         return corrected;
  192.     }

  193.     /** {@inheritDoc} */
  194.     @Override
  195.     public RealMatrix getStateCrossCovariance() {
  196.         final RealVector priorState = utProvider.getUnscentedMeanState(priorSigmaPoints);
  197.         final RealVector predictedState = utProvider.getUnscentedMeanState(predictedNoNoiseSigmaPoints);

  198.         return computeCrossCovarianceMatrix(priorSigmaPoints, priorState, predictedNoNoiseSigmaPoints, predictedState);
  199.     }

  200.     /** Get the unscented transform provider.
  201.      * @return unscented transform provider
  202.      */
  203.     public UnscentedTransformProvider getUnscentedTransformProvider() {
  204.         return utProvider;
  205.     }

  206.     /** Computes innovation covariance matrix.
  207.      * @param predictedMeasurements predicted measurements (one per sigma point)
  208.      * @param predictedMeasurement predicted measurements
  209.      *        (may be null if measurement should be ignored)
  210.      * @param r measurement covariance
  211.      * @return innovation covariance matrix (null if predictedMeasurement is null)
  212.      */
  213.     private RealMatrix computeInnovationCovarianceMatrix(final RealVector[] predictedMeasurements,
  214.                                                          final RealVector predictedMeasurement,
  215.                                                          final RealMatrix r) {
  216.         if (predictedMeasurement == null) {
  217.             return null;
  218.         }
  219.         // Computation of the innovation covariance matrix
  220.         final RealMatrix innovationCovarianceMatrix = utProvider.getUnscentedCovariance(predictedMeasurements, predictedMeasurement);

  221.         // Add the measurement covariance
  222.         return innovationCovarianceMatrix.add(r);
  223.     }

  224.     /**
  225.      * Computes cross covariance matrix.
  226.      * @param predictedStates predicted states
  227.      * @param predictedState predicted state
  228.      * @param predictedMeasurements current measurements
  229.      * @param predictedMeasurement predicted measurements
  230.      * @return cross covariance matrix
  231.      */
  232.     private RealMatrix computeCrossCovarianceMatrix(final RealVector[] predictedStates, final RealVector predictedState,
  233.                                                     final RealVector[] predictedMeasurements, final RealVector predictedMeasurement) {

  234.         // Initialize the cross covariance matrix
  235.         RealMatrix crossCovarianceMatrix = MatrixUtils.createRealMatrix(predictedState.getDimension(),
  236.                                                                         predictedMeasurement.getDimension());

  237.         // Covariance weights
  238.         final RealVector wc = utProvider.getWc();

  239.         // Compute the cross covariance matrix
  240.         for (int i = 0; i <= 2 * n; i++) {
  241.             final RealVector stateDiff = predictedStates[i].subtract(predictedState);
  242.             final RealVector measDiff  = predictedMeasurements[i].subtract(predictedMeasurement);
  243.             crossCovarianceMatrix = crossCovarianceMatrix.add(stateDiff.outerProduct(measDiff).scalarMultiply(wc.getEntry(i)));
  244.         }

  245.         // Return the cross covariance
  246.         return crossCovarianceMatrix;
  247.     }

  248. }