KalmanSmoother.java

  1. /*
  2.  * Licensed to the Hipparchus project under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The Hipparchus project licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      https://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.hipparchus.filtering.kalman;

  18. import org.hipparchus.exception.MathIllegalStateException;
  19. import org.hipparchus.filtering.LocalizedFilterFormats;
  20. import org.hipparchus.linear.MatrixDecomposer;
  21. import org.hipparchus.linear.RealMatrix;
  22. import org.hipparchus.linear.RealVector;

  23. import java.util.ArrayList;
  24. import java.util.LinkedList;
  25. import java.util.List;

  26. /**
  27.  * Kalman smoother for linear, extended or unscented filters.
  28.  * <p>
  29.  * This implementation is attached to a filter using the observer mechanism.  Once all measurements have been
  30.  * processed by the filter, the smoothing method can be called.
  31.  * </p>
  32.  * <p>
  33.  * For example
  34.  * </p>
  35.  * <pre>{@code
  36.  *     // Kalman filter
  37.  *     final KalmanFilter<SimpleMeasurement> filter = new LinearKalmanFilter<>(decomposer, process, initialState);
  38.  *
  39.  *     // Smoother observer
  40.  *     final KalmanSmoother smoother = new KalmanSmoother(decomposer);
  41.  *     filter.setObserver(smoother);
  42.  *
  43.  *     // Process measurements with filter (forwards pass)
  44.  *     measurements.forEach(filter::estimationStep);
  45.  *
  46.  *     // Smooth backwards
  47.  *     List<ProcessEstimate> smoothedStates = smoother.backwardsSmooth();
  48.  * }</pre>
  49.  *
  50.  * @see "Särkkä, S. Bayesian Filtering and Smoothing. Cambridge 2013"
  51.  */
  52. public class KalmanSmoother implements KalmanObserver {

  53.     /** Decomposer to use for gain calculation. */
  54.     private final MatrixDecomposer decomposer;

  55.     /** Storage for smoother gain matrices. */
  56.     private final List<SmootherData> smootherData;

  57.     /** Simple constructor.
  58.      * @param decomposer decomposer to use for the smoother gain calculations
  59.      */
  60.     public KalmanSmoother(final MatrixDecomposer decomposer) {
  61.         this.decomposer = decomposer;
  62.         this.smootherData = new ArrayList<>();
  63.     }

  64.     @Override
  65.     public void init(KalmanEstimate estimate) {
  66.         // Add initial state to smoother data
  67.         smootherData.add(new SmootherData(
  68.                 estimate.getCorrected().getTime(),
  69.                 null,
  70.                 null,
  71.                 estimate.getCorrected().getState(),
  72.                 estimate.getCorrected().getCovariance(),
  73.                 null
  74.         ));

  75.     }

  76.     @Override
  77.     public void updatePerformed(KalmanEstimate estimate) {
  78.         // Smoother gain
  79.         // We want G = D * P^(-1)
  80.         // Calculate with G = (P^(-1) * D^T)^T
  81.         final RealMatrix smootherGain = decomposer
  82.                 .decompose(estimate.getPredicted().getCovariance())
  83.                 .solve(estimate.getStateCrossCovariance().transpose())
  84.                 .transpose();
  85.         smootherData.add(new SmootherData(
  86.                 estimate.getCorrected().getTime(),
  87.                 estimate.getPredicted().getState(),
  88.                 estimate.getPredicted().getCovariance(),
  89.                 estimate.getCorrected().getState(),
  90.                 estimate.getCorrected().getCovariance(),
  91.                 smootherGain
  92.         ));
  93.     }

  94.     /** Backwards smooth.
  95.      * This is a backward pass over the filtered data, recursively calculating smoothed states, using the
  96.      * Rauch-Tung-Striebel (RTS) formulation.
  97.      * Note that the list result is a `LinkedList`, not an `ArrayList`.
  98.      * @return list of smoothed states
  99.      */
  100.     public List<ProcessEstimate> backwardsSmooth() {
  101.         // Check for at least one measurement
  102.         if (smootherData.size() < 2) {
  103.             throw new MathIllegalStateException(LocalizedFilterFormats.PROCESS_AT_LEAST_ONE_MEASUREMENT);
  104.         }

  105.         // Initialise output
  106.         final LinkedList<ProcessEstimate> smootherResults = new LinkedList<>();

  107.         // Last smoothed state is the same as the filtered state
  108.         final SmootherData lastUpdate = smootherData.get(smootherData.size() - 1);
  109.         ProcessEstimate smoothedState = new ProcessEstimate(lastUpdate.getTime(),
  110.                 lastUpdate.getCorrectedState(), lastUpdate.getCorrectedCovariance());
  111.         smootherResults.addFirst(smoothedState);

  112.         // Backwards recursion on the smoothed state
  113.         for (int i = smootherData.size() - 2; i >= 0; --i) {

  114.             // These are from equation 8.6 in Sarkka, "Bayesian Filtering and Smoothing", Cambridge, 2013.
  115.             final RealMatrix smootherGain = smootherData.get(i + 1).getSmootherGain();

  116.             final RealVector smoothedMean = smootherData.get(i).getCorrectedState()
  117.                     .add(smootherGain.operate(smoothedState.getState()
  118.                             .subtract(smootherData.get(i + 1).getPredictedState())));

  119.             final RealMatrix smoothedCovariance = smootherData.get(i).getCorrectedCovariance()
  120.                     .add(smootherGain.multiply(smoothedState.getCovariance()
  121.                                     .subtract(smootherData.get(i + 1).getPredictedCovariance()))
  122.                             .multiplyTransposed(smootherGain));

  123.             // Populate smoothed state
  124.             smoothedState = new ProcessEstimate(smootherData.get(i).getTime(), smoothedMean, smoothedCovariance);
  125.             smootherResults.addFirst(smoothedState);
  126.         }

  127.         return smootherResults;
  128.     }

  129.     /** Container for smoother data. */
  130.     private static class SmootherData {
  131.         /** Process time (typically the time or index of a measurement). */
  132.         private final double time;

  133.         /** Predicted state vector. */
  134.         private final RealVector predictedState;

  135.         /** Predicted covariance. */
  136.         private final RealMatrix predictedCovariance;

  137.         /** Corrected state vector. */
  138.         private final RealVector correctedState;

  139.         /** Corrected covariance. */
  140.         private final RealMatrix correctedCovariance;

  141.         /** Smoother gain. */
  142.         private final RealMatrix smootherGain;

  143.         SmootherData(final double time,
  144.                      final RealVector predictedState,
  145.                      final RealMatrix predictedCovariance,
  146.                      final RealVector correctedState,
  147.                      final RealMatrix correctedCovariance,
  148.                      final RealMatrix smootherGain) {
  149.             this.time = time;
  150.             this.predictedState = predictedState;
  151.             this.predictedCovariance = predictedCovariance;
  152.             this.correctedState = correctedState;
  153.             this.correctedCovariance = correctedCovariance;
  154.             this.smootherGain = smootherGain;
  155.         }

  156.         /** Get the process time.
  157.          * @return process time (typically the time or index of a measurement)
  158.          */
  159.         public double getTime() {
  160.             return time;
  161.         }

  162.         /**
  163.          * Get predicted state
  164.          * @return predicted state
  165.          */
  166.         public RealVector getPredictedState() {
  167.             return predictedState;
  168.         }

  169.         /**
  170.          * Get predicted covariance
  171.          * @return predicted covariance
  172.          */
  173.         public RealMatrix getPredictedCovariance() {
  174.             return predictedCovariance;
  175.         }

  176.         /**
  177.          * Get corrected state
  178.          * @return corrected state
  179.          */
  180.         public RealVector getCorrectedState() {
  181.             return correctedState;
  182.         }

  183.         /**
  184.          * Get corrected covariance
  185.          * @return corrected covariance
  186.          */
  187.         public RealMatrix getCorrectedCovariance() {
  188.             return correctedCovariance;
  189.         }

  190.         /**
  191.          * Get smoother gain (for previous time-step)
  192.          * @return smoother gain
  193.          */
  194.         public RealMatrix getSmootherGain() {
  195.             return smootherGain;
  196.         }
  197.     }

  198. }