UnscentedKalmanFilter.java

/*
 * Licensed to the Hipparchus project under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The Hipparchus project licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.hipparchus.filtering.kalman.unscented;

import org.hipparchus.exception.LocalizedCoreFormats;
import org.hipparchus.exception.MathIllegalArgumentException;
import org.hipparchus.exception.MathRuntimeException;
import org.hipparchus.filtering.kalman.KalmanFilter;
import org.hipparchus.filtering.kalman.Measurement;
import org.hipparchus.filtering.kalman.ProcessEstimate;
import org.hipparchus.linear.MatrixDecomposer;
import org.hipparchus.linear.MatrixUtils;
import org.hipparchus.linear.RealMatrix;
import org.hipparchus.linear.RealVector;
import org.hipparchus.util.UnscentedTransformProvider;

/**
 * Unscented Kalman filter for {@link UnscentedProcess unscented process}.
 * @param <T> the type of the measurements
 *
 * @see "Wan, E. A., & Van Der Merwe, R. (2000, October). The unscented Kalman filter for nonlinear estimation.
 *       In Proceedings of the IEEE 2000 Adaptive Systems for Signal Processing, Communications, and Control Symposium
 *       (Cat. No. 00EX373) (pp. 153-158)"
 * @since 2.2
 */
public class UnscentedKalmanFilter<T extends Measurement> implements KalmanFilter<T> {

    /** Process to be estimated. */
    private final UnscentedProcess<T> process;

    /** Predicted state. */
    private ProcessEstimate predicted;

    /** Corrected state. */
    private ProcessEstimate corrected;

    /** Decomposer to use for the correction phase. */
    private final MatrixDecomposer decomposer;

    /** Number of estimated parameters. */
    private final int n;

    /** Unscented transform provider. */
    private final UnscentedTransformProvider utProvider;

    /** Simple constructor.
     * @param decomposer decomposer to use for the correction phase
     * @param process unscented process to estimate
     * @param initialState initial state
     * @param utProvider unscented transform provider
     * @exception MathIllegalArgumentException if the initial state has dimension 0
     */
    public UnscentedKalmanFilter(final MatrixDecomposer decomposer,
                                 final UnscentedProcess<T> process,
                                 final ProcessEstimate initialState,
                                 final UnscentedTransformProvider utProvider) {
        this.decomposer = decomposer;
        this.process    = process;
        this.corrected  = initialState;
        this.n          = corrected.getState().getDimension();
        this.utProvider = utProvider;
        // Check state dimension
        if (n == 0) {
            // State dimension must be different from 0
            throw new MathIllegalArgumentException(LocalizedCoreFormats.ZERO_STATE_SIZE);
        }
    }

    /** {@inheritDoc} */
    @Override
    public ProcessEstimate estimationStep(final T measurement) throws MathRuntimeException {

        // Calculate sigma points from the last corrected state
        final RealVector[] sigmaPoints = utProvider.unscentedTransform(corrected.getState(), corrected.getCovariance());

        // Perform the prediction and correction steps
        return predictionAndCorrectionSteps(measurement, sigmaPoints);

    }

    /** This method performs the prediction and correction steps of the Unscented Kalman Filter.
     * @param measurement single measurement to handle
     * @param sigmaPoints computed sigma points
     * @return estimated state after measurement has been considered
     * @throws MathRuntimeException if matrix cannot be decomposed
     */
    public ProcessEstimate predictionAndCorrectionSteps(final T measurement, final RealVector[] sigmaPoints) throws MathRuntimeException {

        // Prediction phase
        final UnscentedEvolution evolution = process.getEvolution(getCorrected().getTime(),
                                                                  sigmaPoints, measurement);

        predict(evolution.getCurrentTime(), evolution.getCurrentStates(),
                evolution.getProcessNoiseMatrix());

        // Calculate sigma points from predicted state
        final RealVector[] predictedSigmaPoints = utProvider.unscentedTransform(predicted.getState(),
                                                                                predicted.getCovariance());

        // Correction phase
        final RealVector[] predictedMeasurements = process.getPredictedMeasurements(predictedSigmaPoints, measurement);
        final RealVector   predictedMeasurement  = utProvider.getUnscentedMeanState(predictedMeasurements);
        final RealMatrix   r                     = computeInnovationCovarianceMatrix(predictedMeasurements, predictedMeasurement, measurement.getCovariance());

        // A null innovation covariance (null predicted measurement) means the measurement
        // must be ignored. The cross covariance cannot be computed in that case, since it
        // needs the predicted measurement, and it is not needed by the correction anyway.
        final RealMatrix crossCovarianceMatrix;
        final RealVector innovation;
        if (r == null) {
            crossCovarianceMatrix = null;
            innovation            = null;
        } else {
            crossCovarianceMatrix = computeCrossCovarianceMatrix(predictedSigmaPoints, predicted.getState(),
                                                                 predictedMeasurements, predictedMeasurement);
            innovation            = process.getInnovation(measurement, predictedMeasurement, predicted.getState(), r);
        }
        correct(measurement, r, crossCovarianceMatrix, innovation);
        return getCorrected();

    }

    /** Perform prediction step.
     * <p>
     * Sets the predicted estimate and clears the corrected one until
     * {@link #correct(Measurement, RealMatrix, RealMatrix, RealVector) correct} runs.
     * </p>
     * @param time process time
     * @param predictedStates predicted state vectors
     * @param noise process noise covariance matrix
     */
    private void predict(final double time, final RealVector[] predictedStates,  final RealMatrix noise) {

        // Computation of Eq. 17, weighted mean state
        final RealVector predictedState = utProvider.getUnscentedMeanState(predictedStates);

        // Computation of Eq. 18, predicted covariance matrix
        final RealMatrix predictedCovariance = utProvider.getUnscentedCovariance(predictedStates, predictedState).add(noise);

        predicted = new ProcessEstimate(time, predictedState, predictedCovariance);
        corrected = null;

    }

    /** Perform correction step.
     * @param measurement single measurement to handle
     * @param innovationCovarianceMatrix innovation covariance matrix
     * (may be null if measurement should be ignored)
     * @param crossCovarianceMatrix cross covariance matrix
     * (may be null if measurement should be ignored)
     * @param innovation innovation
     * (may be null if measurement should be ignored)
     * @exception MathIllegalArgumentException if matrix cannot be decomposed
     */
    private void correct(final T measurement, final RealMatrix innovationCovarianceMatrix,
                           final RealMatrix crossCovarianceMatrix, final RealVector innovation)
        throws MathIllegalArgumentException {

        if (innovation == null) {
            // measurement should be ignored
            corrected = predicted;
            return;
        }

        // compute Kalman gain k
        // the following is equivalent to k = P_cross * (R_pred)^-1
        // we don't want to compute the inverse of a matrix,
        // we start by post-multiplying by R_pred and get
        // k.(R_pred) = P_cross
        // then we transpose, knowing that R_pred is a symmetric matrix
        // (R_pred).k^T = P_cross^T
        // then we can use linear system solving instead of matrix inversion
        final RealMatrix k = decomposer.
                             decompose(innovationCovarianceMatrix).
                             solve(crossCovarianceMatrix.transpose()).transpose();

        // correct state vector
        final RealVector correctedState = predicted.getState().add(k.operate(innovation));

        // correct covariance matrix: P - k.R_pred.k^T
        final RealMatrix correctedCovariance = predicted.getCovariance().
                                               subtract(k.multiply(innovationCovarianceMatrix).multiplyTransposed(k));

        corrected = new ProcessEstimate(measurement.getTime(), correctedState, correctedCovariance,
                                        null, null, innovationCovarianceMatrix, k);

    }

    /** Get the predicted state.
     * @return predicted state
     */
    @Override
    public ProcessEstimate getPredicted() {
        return predicted;
    }

    /** Get the corrected state.
     * @return corrected state (null between a prediction and its correction)
     */
    @Override
    public ProcessEstimate getCorrected() {
        return corrected;
    }

    /** Get the unscented transform provider.
     * @return unscented transform provider
     */
    public UnscentedTransformProvider getUnscentedTransformProvider() {
        return utProvider;
    }

    /** Computes innovation covariance matrix.
     * @param predictedMeasurements predicted measurements (one per sigma point)
     * @param predictedMeasurement weighted mean of the predicted measurements
     *        (may be null if measurement should be ignored)
     * @param r measurement covariance
     * @return innovation covariance matrix (null if predictedMeasurement is null)
     */
    private RealMatrix computeInnovationCovarianceMatrix(final RealVector[] predictedMeasurements,
                                                         final RealVector predictedMeasurement,
                                                         final RealMatrix r) {
        if (predictedMeasurement == null) {
            return null;
        }
        // Computation of the innovation covariance matrix
        final RealMatrix innovationCovarianceMatrix = utProvider.getUnscentedCovariance(predictedMeasurements, predictedMeasurement);

        // Add the measurement covariance
        return innovationCovarianceMatrix.add(r);
    }

    /**
     * Computes cross covariance matrix between predicted states and predicted measurements.
     * <p>
     * The result is the weighted sum over all sigma points i of
     * wc[i] * (X_i - x) * (Z_i - z)^T, where wc are the covariance weights
     * of the unscented transform provider.
     * </p>
     * @param predictedStates predicted sigma points (states)
     * @param predictedState weighted mean of the predicted sigma points
     * @param predictedMeasurements predicted measurements (one per sigma point)
     * @param predictedMeasurement weighted mean of the predicted measurements
     * @return cross covariance matrix (state dimension x measurement dimension)
     */
    private RealMatrix computeCrossCovarianceMatrix(final RealVector[] predictedStates, final RealVector predictedState,
                                                    final RealVector[] predictedMeasurements, final RealVector predictedMeasurement) {

        // Initialize the cross covariance matrix
        final RealMatrix crossCovarianceMatrix = MatrixUtils.createRealMatrix(predictedState.getDimension(),
                                                                              predictedMeasurement.getDimension());

        // Covariance weights
        final RealVector wc = utProvider.getWc();

        // Accumulate the weighted outer products in place, without building a temporary
        // matrix per sigma point. Iterate over the sigma points actually provided rather
        // than assuming 2n+1 of them.
        for (int i = 0; i < predictedStates.length; i++) {
            final RealVector stateDiff = predictedStates[i].subtract(predictedState);
            final RealVector measDiff  = predictedMeasurements[i].subtract(predictedMeasurement);
            final double     weight    = wc.getEntry(i);
            for (int row = 0; row < stateDiff.getDimension(); row++) {
                final double s = stateDiff.getEntry(row);
                for (int col = 0; col < measDiff.getDimension(); col++) {
                    // same rounding order as (outer product) then scalar multiply: (s * m) * w
                    crossCovarianceMatrix.addToEntry(row, col, (s * measDiff.getEntry(col)) * weight);
                }
            }
        }

        // Return the cross covariance
        return crossCovarianceMatrix;
    }

}