KalmanSmoother.java
- /*
- * Licensed to the Hipparchus project under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The Hipparchus project licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- package org.hipparchus.filtering.kalman;
- import org.hipparchus.exception.MathIllegalStateException;
- import org.hipparchus.filtering.LocalizedFilterFormats;
- import org.hipparchus.linear.MatrixDecomposer;
- import org.hipparchus.linear.RealMatrix;
- import org.hipparchus.linear.RealVector;
- import java.util.ArrayList;
- import java.util.LinkedList;
- import java.util.List;
- /**
- * Kalman smoother for linear, extended or unscented filters.
- * <p>
- * This implementation is attached to a filter using the observer mechanism. Once all measurements have been
- * processed by the filter, the smoothing method can be called.
- * </p>
- * <p>
- * For example
- * </p>
- * <pre>{@code
- * // Kalman filter
- * final KalmanFilter<SimpleMeasurement> filter = new LinearKalmanFilter<>(decomposer, process, initialState);
- *
- * // Smoother observer
- * final KalmanSmoother smoother = new KalmanSmoother(decomposer);
- * filter.setObserver(smoother);
- *
- * // Process measurements with filter (forwards pass)
- * measurements.forEach(filter::estimationStep);
- *
- * // Smooth backwards
- * List<ProcessEstimate> smoothedStates = smoother.backwardsSmooth();
- * }</pre>
- *
- * @see "Särkkä, S. Bayesian Filtering and Smoothing. Cambridge 2013"
- */
- public class KalmanSmoother implements KalmanObserver {
- /** Decomposer to use for gain calculation. */
- private final MatrixDecomposer decomposer;
- /** Storage for smoother gain matrices. */
- private final List<SmootherData> smootherData;
- /** Simple constructor.
- * @param decomposer decomposer to use for the smoother gain calculations
- */
- public KalmanSmoother(final MatrixDecomposer decomposer) {
- this.decomposer = decomposer;
- this.smootherData = new ArrayList<>();
- }
- @Override
- public void init(KalmanEstimate estimate) {
- // Add initial state to smoother data
- smootherData.add(new SmootherData(
- estimate.getCorrected().getTime(),
- null,
- null,
- estimate.getCorrected().getState(),
- estimate.getCorrected().getCovariance(),
- null
- ));
- }
- @Override
- public void updatePerformed(KalmanEstimate estimate) {
- // Smoother gain
- // We want G = D * P^(-1)
- // Calculate with G = (P^(-1) * D^T)^T
- final RealMatrix smootherGain = decomposer
- .decompose(estimate.getPredicted().getCovariance())
- .solve(estimate.getStateCrossCovariance().transpose())
- .transpose();
- smootherData.add(new SmootherData(
- estimate.getCorrected().getTime(),
- estimate.getPredicted().getState(),
- estimate.getPredicted().getCovariance(),
- estimate.getCorrected().getState(),
- estimate.getCorrected().getCovariance(),
- smootherGain
- ));
- }
- /** Backwards smooth.
- * This is a backward pass over the filtered data, recursively calculating smoothed states, using the
- * Rauch-Tung-Striebel (RTS) formulation.
- * Note that the list result is a `LinkedList`, not an `ArrayList`.
- * @return list of smoothed states
- */
- public List<ProcessEstimate> backwardsSmooth() {
- // Check for at least one measurement
- if (smootherData.size() < 2) {
- throw new MathIllegalStateException(LocalizedFilterFormats.PROCESS_AT_LEAST_ONE_MEASUREMENT);
- }
- // Initialise output
- final LinkedList<ProcessEstimate> smootherResults = new LinkedList<>();
- // Last smoothed state is the same as the filtered state
- final SmootherData lastUpdate = smootherData.get(smootherData.size() - 1);
- ProcessEstimate smoothedState = new ProcessEstimate(lastUpdate.getTime(),
- lastUpdate.getCorrectedState(), lastUpdate.getCorrectedCovariance());
- smootherResults.addFirst(smoothedState);
- // Backwards recursion on the smoothed state
- for (int i = smootherData.size() - 2; i >= 0; --i) {
- // These are from equation 8.6 in Sarkka, "Bayesian Filtering and Smoothing", Cambridge, 2013.
- final RealMatrix smootherGain = smootherData.get(i + 1).getSmootherGain();
- final RealVector smoothedMean = smootherData.get(i).getCorrectedState()
- .add(smootherGain.operate(smoothedState.getState()
- .subtract(smootherData.get(i + 1).getPredictedState())));
- final RealMatrix smoothedCovariance = smootherData.get(i).getCorrectedCovariance()
- .add(smootherGain.multiply(smoothedState.getCovariance()
- .subtract(smootherData.get(i + 1).getPredictedCovariance()))
- .multiplyTransposed(smootherGain));
- // Populate smoothed state
- smoothedState = new ProcessEstimate(smootherData.get(i).getTime(), smoothedMean, smoothedCovariance);
- smootherResults.addFirst(smoothedState);
- }
- return smootherResults;
- }
- /** Container for smoother data. */
- private static class SmootherData {
- /** Process time (typically the time or index of a measurement). */
- private final double time;
- /** Predicted state vector. */
- private final RealVector predictedState;
- /** Predicted covariance. */
- private final RealMatrix predictedCovariance;
- /** Corrected state vector. */
- private final RealVector correctedState;
- /** Corrected covariance. */
- private final RealMatrix correctedCovariance;
- /** Smoother gain. */
- private final RealMatrix smootherGain;
- SmootherData(final double time,
- final RealVector predictedState,
- final RealMatrix predictedCovariance,
- final RealVector correctedState,
- final RealMatrix correctedCovariance,
- final RealMatrix smootherGain) {
- this.time = time;
- this.predictedState = predictedState;
- this.predictedCovariance = predictedCovariance;
- this.correctedState = correctedState;
- this.correctedCovariance = correctedCovariance;
- this.smootherGain = smootherGain;
- }
- /** Get the process time.
- * @return process time (typically the time or index of a measurement)
- */
- public double getTime() {
- return time;
- }
- /**
- * Get predicted state
- * @return predicted state
- */
- public RealVector getPredictedState() {
- return predictedState;
- }
- /**
- * Get predicted covariance
- * @return predicted covariance
- */
- public RealMatrix getPredictedCovariance() {
- return predictedCovariance;
- }
- /**
- * Get corrected state
- * @return corrected state
- */
- public RealVector getCorrectedState() {
- return correctedState;
- }
- /**
- * Get corrected covariance
- * @return corrected covariance
- */
- public RealMatrix getCorrectedCovariance() {
- return correctedCovariance;
- }
- /**
- * Get smoother gain (for previous time-step)
- * @return smoother gain
- */
- public RealMatrix getSmootherGain() {
- return smootherGain;
- }
- }
- }