KalmanSmoother.java
/*
* Licensed to the Hipparchus project under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The Hipparchus project licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.hipparchus.filtering.kalman;
import org.hipparchus.exception.MathIllegalStateException;
import org.hipparchus.filtering.LocalizedFilterFormats;
import org.hipparchus.linear.MatrixDecomposer;
import org.hipparchus.linear.RealMatrix;
import org.hipparchus.linear.RealVector;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
/**
* Kalman smoother for linear, extended or unscented filters.
* <p>
* This implementation is attached to a filter using the observer mechanism. Once all measurements have been
* processed by the filter, the smoothing method can be called.
* </p>
* <p>
* For example
* </p>
* <pre>{@code
* // Kalman filter
* final KalmanFilter<SimpleMeasurement> filter = new LinearKalmanFilter<>(decomposer, process, initialState);
*
* // Smoother observer
* final KalmanSmoother smoother = new KalmanSmoother(decomposer);
* filter.setObserver(smoother);
*
* // Process measurements with filter (forwards pass)
* measurements.forEach(filter::estimationStep);
*
* // Smooth backwards
* List<ProcessEstimate> smoothedStates = smoother.backwardsSmooth();
* }</pre>
*
* @see "Särkkä, S. Bayesian Filtering and Smoothing. Cambridge 2013"
*/
public class KalmanSmoother implements KalmanObserver {

    /** Decomposer to use for gain calculation. */
    private final MatrixDecomposer decomposer;

    /** Storage for smoother data: the initial state followed by one entry per filter update. */
    private final List<SmootherData> smootherData;

    /** Simple constructor.
     * @param decomposer decomposer to use for the smoother gain calculations
     */
    public KalmanSmoother(final MatrixDecomposer decomposer) {
        this.decomposer = decomposer;
        this.smootherData = new ArrayList<>();
    }

    /** {@inheritDoc}
     * <p>
     * Any data stored during a previous forward pass is discarded, so the smoother can be attached
     * again (to the same or to another filter) without mixing unrelated runs. The corrected state of
     * {@code estimate} is then stored as the starting point of the new pass.
     * </p>
     */
    @Override
    public void init(final KalmanEstimate estimate) {
        // Start a fresh forward pass: stale entries from an earlier pass would otherwise be chained
        // into the backward recursion, and the earlier initial entry has no gain (null), which
        // would make backwardsSmooth() fail.
        smootherData.clear();

        // Add initial state to smoother data (no prediction and no gain exist for it)
        final ProcessEstimate initial = estimate.getCorrected();
        smootherData.add(new SmootherData(
                initial.getTime(),
                null,
                null,
                initial.getState(),
                initial.getCovariance(),
                null
        ));
    }

    /** {@inheritDoc}
     * <p>
     * Stores the predicted and corrected estimates of this step together with the smoother gain
     * G = D P<sup>-1</sup>, where P is the predicted covariance and D is the state cross-covariance
     * provided by {@code estimate.getStateCrossCovariance()}.
     * </p>
     */
    @Override
    public void updatePerformed(final KalmanEstimate estimate) {
        final ProcessEstimate predicted = estimate.getPredicted();
        final ProcessEstimate corrected = estimate.getCorrected();

        // Smoother gain
        // We want G = D * P^(-1)
        // Calculate with G = (P^(-1) * D^T)^T = D * P^(-T), which equals D * P^(-1) because the
        // covariance P is symmetric; solving avoids forming an explicit inverse.
        final RealMatrix smootherGain = decomposer
                .decompose(predicted.getCovariance())
                .solve(estimate.getStateCrossCovariance().transpose())
                .transpose();

        smootherData.add(new SmootherData(
                corrected.getTime(),
                predicted.getState(),
                predicted.getCovariance(),
                corrected.getState(),
                corrected.getCovariance(),
                smootherGain
        ));
    }

    /** Backwards smooth.
     * <p>
     * This is a backward pass over the filtered data, recursively calculating smoothed states using the
     * Rauch-Tung-Striebel (RTS) formulation. The result holds one estimate per stored entry (the initial
     * state followed by one per processed measurement), in the order they were recorded.
     * </p>
     * <p>
     * Note that the list result is a {@link LinkedList}, not an {@link ArrayList}.
     * </p>
     * @return list of smoothed states
     * @throws MathIllegalStateException if no measurement update has been observed since initialization
     */
    public List<ProcessEstimate> backwardsSmooth() {
        // Check for at least one measurement (entry 0 is the initial state)
        if (smootherData.size() < 2) {
            throw new MathIllegalStateException(LocalizedFilterFormats.PROCESS_AT_LEAST_ONE_MEASUREMENT);
        }

        // Initialise output
        final LinkedList<ProcessEstimate> smootherResults = new LinkedList<>();

        // Last smoothed state is the same as the filtered state
        final SmootherData lastUpdate = smootherData.get(smootherData.size() - 1);
        ProcessEstimate smoothedState = new ProcessEstimate(lastUpdate.getTime(),
                lastUpdate.getCorrectedState(), lastUpdate.getCorrectedCovariance());
        smootherResults.addFirst(smoothedState);

        // Backwards recursion on the smoothed state
        for (int i = smootherData.size() - 2; i >= 0; --i) {
            final SmootherData current = smootherData.get(i);
            final SmootherData next    = smootherData.get(i + 1);

            // These are from equation 8.6 in Sarkka, "Bayesian Filtering and Smoothing", Cambridge, 2013.
            // The gain linking step i to step i + 1 is stored with step i + 1.
            final RealMatrix smootherGain = next.getSmootherGain();

            // m_s[i] = m[i] + G (m_s[i+1] - m^-[i+1])
            final RealVector smoothedMean = current.getCorrectedState()
                    .add(smootherGain.operate(smoothedState.getState()
                            .subtract(next.getPredictedState())));

            // P_s[i] = P[i] + G (P_s[i+1] - P^-[i+1]) G^T
            final RealMatrix smoothedCovariance = current.getCorrectedCovariance()
                    .add(smootherGain.multiply(smoothedState.getCovariance()
                                    .subtract(next.getPredictedCovariance()))
                            .multiplyTransposed(smootherGain));

            // Populate smoothed state
            smoothedState = new ProcessEstimate(current.getTime(), smoothedMean, smoothedCovariance);
            smootherResults.addFirst(smoothedState);
        }

        return smootherResults;
    }

    /** Container for smoother data. */
    private static class SmootherData {

        /** Process time (typically the time or index of a measurement). */
        private final double time;

        /** Predicted state vector ({@code null} for the initial state). */
        private final RealVector predictedState;

        /** Predicted covariance ({@code null} for the initial state). */
        private final RealMatrix predictedCovariance;

        /** Corrected state vector. */
        private final RealVector correctedState;

        /** Corrected covariance. */
        private final RealMatrix correctedCovariance;

        /** Smoother gain ({@code null} for the initial state). */
        private final RealMatrix smootherGain;

        /** Simple constructor.
         * @param time process time
         * @param predictedState predicted state vector (may be {@code null} for the initial state)
         * @param predictedCovariance predicted covariance (may be {@code null} for the initial state)
         * @param correctedState corrected state vector
         * @param correctedCovariance corrected covariance
         * @param smootherGain smoother gain linking the previous step to this one
         *                     (may be {@code null} for the initial state)
         */
        SmootherData(final double time,
                     final RealVector predictedState,
                     final RealMatrix predictedCovariance,
                     final RealVector correctedState,
                     final RealMatrix correctedCovariance,
                     final RealMatrix smootherGain) {
            this.time                = time;
            this.predictedState      = predictedState;
            this.predictedCovariance = predictedCovariance;
            this.correctedState      = correctedState;
            this.correctedCovariance = correctedCovariance;
            this.smootherGain        = smootherGain;
        }

        /** Get the process time.
         * @return process time (typically the time or index of a measurement)
         */
        public double getTime() {
            return time;
        }

        /**
         * Get predicted state.
         * @return predicted state
         */
        public RealVector getPredictedState() {
            return predictedState;
        }

        /**
         * Get predicted covariance.
         * @return predicted covariance
         */
        public RealMatrix getPredictedCovariance() {
            return predictedCovariance;
        }

        /**
         * Get corrected state.
         * @return corrected state
         */
        public RealVector getCorrectedState() {
            return correctedState;
        }

        /**
         * Get corrected covariance.
         * @return corrected covariance
         */
        public RealMatrix getCorrectedCovariance() {
            return correctedCovariance;
        }

        /**
         * Get smoother gain (for previous time-step).
         * @return smoother gain
         */
        public RealMatrix getSmootherGain() {
            return smootherGain;
        }

    }

}