UnscentedTransformProvider.java
/*
* Licensed to the Hipparchus project under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The Hipparchus project licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.hipparchus.util;
import org.hipparchus.linear.ArrayRealVector;
import org.hipparchus.linear.MatrixUtils;
import org.hipparchus.linear.RealMatrix;
import org.hipparchus.linear.RealVector;
/**
* Provider for unscented transform.
* @since 2.2
*/
public interface UnscentedTransformProvider {

    /**
     * Perform the unscented transform from a state and its covariance.
     * @param state process state
     * @param covariance covariance associated with the process state
     * @return an array containing the sigma points of the unscented transform
     */
    RealVector[] unscentedTransform(RealVector state, RealMatrix covariance);

    /**
     * Computes a weighted mean state from a given set of sigma points.
     * <p>
     * This method can be used for computing both the mean state and the mean measurement
     * in an Unscented Kalman filter.
     * </p>
     * <p>
     * It corresponds to Equation 17 of "Wan, E. A., & Van Der Merwe, R. The unscented Kalman filter for nonlinear estimation"
     * </p>
     * <p>
     * Sigma point {@code i} is weighted by entry {@code i} of {@link #getWm()}.
     * The array must be non-empty, and all its points must share the dimension of the first one.
     * It must not contain more points than there are mean weights.
     * Only the first {@code sigmaPoints.length} weights are used.
     * </p>
     * @param sigmaPoints input samples
     * @return weighted mean state
     */
    default RealVector getUnscentedMeanState(RealVector[] sigmaPoints) {

        // Sigma point dimension (the first point defines it)
        final int sigmaPointDimension = sigmaPoints[0].getDimension();

        // Compute weighted mean
        // ---------------------
        final RealVector weightedMean = new ArrayRealVector(sigmaPointDimension);

        // Compute the weight coefficients wm
        final RealVector wm = getWm();

        // Weight each sigma point and sum them.
        // combineToSelf(1, w, y) updates the accumulator to (1 * mean + w * y) in place,
        // instead of allocating a scaled copy and a new sum vector for every sigma point.
        for (int i = 0; i < sigmaPoints.length; i++) {
            weightedMean.combineToSelf(1.0, wm.getEntry(i), sigmaPoints[i]);
        }

        return weightedMean;
    }

    /** Computes the unscented covariance matrix from a weighted mean state and a set of sigma points.
     * <p>
     * This method can be used for computing both the predicted state
     * covariance matrix and the innovation covariance matrix in an Unscented Kalman filter.
     * </p>
     * <p>
     * It corresponds to Equation 18 of "Wan, E. A., & Van Der Merwe, R. The unscented Kalman filter for nonlinear estimation"
     * </p>
     * <p>
     * The result is the sum over {@code i} of {@code wc[i] * (sigmaPoints[i] - meanState) * (sigmaPoints[i] - meanState)^T},
     * where {@code wc} is {@link #getWc()}.
     * Every sigma point must have the same dimension as {@code meanState}.
     * The array must not contain more points than there are covariance weights.
     * </p>
     * @param sigmaPoints input sigma points
     * @param meanState weighted mean state
     * @return the unscented covariance matrix
     */
    default RealMatrix getUnscentedCovariance(RealVector[] sigmaPoints, RealVector meanState) {

        // State dimension
        final int stateDimension = meanState.getDimension();

        // Compute covariance matrix
        // -------------------------
        final RealMatrix covarianceMatrix = MatrixUtils.createRealMatrix(stateDimension, stateDimension);

        // Compute the weight coefficients wc
        final RealVector wc = getWc();

        // Reconstruct the covariance
        for (int i = 0; i < sigmaPoints.length; i++) {
            final double[] diff   = sigmaPoints[i].subtract(meanState).toArray();
            final double   weight = wc.getEntry(i);
            // Add weight * diff * diff^T entry by entry. The product is evaluated as
            // (diff[row] * diff[col]) * weight, as with the former matrix form, but
            // without allocating a column matrix, an outer product, a scaled copy
            // and a new sum matrix for every sigma point.
            for (int row = 0; row < stateDimension; row++) {
                final double diffRow = diff[row];
                for (int col = 0; col < stateDimension; col++) {
                    covarianceMatrix.addToEntry(row, col, diffRow * diff[col] * weight);
                }
            }
        }

        return covarianceMatrix;
    }

    /**
     * Perform the inverse unscented transform from an array of sigma points.
     * <p>
     * The mean state is computed with {@link #getUnscentedMeanState(RealVector[])}.
     * It is then used to compute the covariance with {@link #getUnscentedCovariance(RealVector[], RealVector)}.
     * </p>
     * @param sigmaPoints array containing the sigma points of the unscented transform
     * @return mean state and associated covariance
     */
    default Pair<RealVector, RealMatrix> inverseUnscentedTransform(RealVector[] sigmaPoints) {

        // Mean state
        final RealVector meanState = getUnscentedMeanState(sigmaPoints);

        // Return state and covariance
        return new Pair<>(meanState, getUnscentedCovariance(sigmaPoints, meanState));
    }

    /**
     * Get the covariance weights.
     * <p>
     * Entry {@code i} is applied to sigma point {@code i} by
     * {@link #getUnscentedCovariance(RealVector[], RealVector)}.
     * </p>
     * @return the covariance weights
     */
    RealVector getWc();

    /**
     * Get the mean weights.
     * <p>
     * Entry {@code i} is applied to sigma point {@code i} by
     * {@link #getUnscentedMeanState(RealVector[])}.
     * </p>
     * @return the mean weights
     */
    RealVector getWm();

}