MultivariateNormalMixtureExpectationMaximization.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      https://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. /*
  18.  * This is not the original file distributed by the Apache Software Foundation
  19.  * It has been modified by the Hipparchus project
  20.  */
  21. package org.hipparchus.stat.fitting;

  22. import java.util.ArrayList;
  23. import java.util.Arrays;
  24. import java.util.List;

  25. import org.hipparchus.distribution.multivariate.MixtureMultivariateNormalDistribution;
  26. import org.hipparchus.distribution.multivariate.MultivariateNormalDistribution;
  27. import org.hipparchus.exception.LocalizedCoreFormats;
  28. import org.hipparchus.exception.MathIllegalArgumentException;
  29. import org.hipparchus.exception.MathIllegalStateException;
  30. import org.hipparchus.linear.Array2DRowRealMatrix;
  31. import org.hipparchus.linear.RealMatrix;
  32. import org.hipparchus.stat.correlation.Covariance;
  33. import org.hipparchus.util.FastMath;
  34. import org.hipparchus.util.MathArrays;
  35. import org.hipparchus.util.Pair;

  36. /**
  37.  * Expectation-Maximization algorithm for fitting the parameters of
  38.  * multivariate normal mixture model distributions.
  39.  *
  40.  * This implementation is pure original code based on <a
  41.  * href="https://www.ee.washington.edu/techsite/papers/documents/UWEETR-2010-0002.pdf">
  42.  * EM Demystified: An Expectation-Maximization Tutorial</a> by Yihua Chen and Maya R. Gupta,
  43.  * Department of Electrical Engineering, University of Washington, Seattle, WA 98195.
  44.  * It was verified using external tools like <a
  45.  * href="http://cran.r-project.org/web/packages/mixtools/index.html">CRAN Mixtools</a>
  46.  * (see the JUnit test cases) but it is <strong>not</strong> based on Mixtools code at all.
  47.  * The discussion of the origin of this class can be seen in the comments of the <a
  48.  * href="https://issues.apache.org/jira/browse/MATH-817">MATH-817</a> JIRA issue.
  49.  */
  50. public class MultivariateNormalMixtureExpectationMaximization {
  51.     /** Default maximum number of iterations allowed per fitting process. */
  52.     private static final int DEFAULT_MAX_ITERATIONS = 1000;
  53.     /** Default convergence threshold for fitting. */
  54.     private static final double DEFAULT_THRESHOLD = 1E-5;
  55.     /** The data to fit. */
  56.     private final double[][] data;
  57.     /** The model fit against the data. */
  58.     private MixtureMultivariateNormalDistribution fittedModel;
  59.     /** The log likelihood of the data given the fitted model. */
  60.     private double logLikelihood;

  61.     /**
  62.      * Creates an object to fit a multivariate normal mixture model to data.
  63.      *
  64.      * @param data Data to use in fitting procedure
  65.      * @throws MathIllegalArgumentException if data has no rows
  66.      * @throws MathIllegalArgumentException if rows of data have different numbers
  67.      * of columns
  68.      * @throws MathIllegalArgumentException if the number of columns in the data is
  69.      * less than 2
  70.      */
  71.     public MultivariateNormalMixtureExpectationMaximization(double[][] data)
  72.         throws MathIllegalArgumentException {
  73.         if (data.length < 1) {
  74.             throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL,
  75.                                                    data.length, 1);
  76.         }

  77.         this.data = new double[data.length][data[0].length];

  78.         for (int i = 0; i < data.length; i++) {
  79.             if (data[i].length != data[0].length) {
  80.                 // Jagged arrays not allowed
  81.                 throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
  82.                                                        data[i].length, data[0].length);
  83.             }
  84.             if (data[i].length < 2) {
  85.                 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL,
  86.                                                     data[i].length, 2, true);
  87.             }
  88.             this.data[i] = data[i].clone();
  89.         }
  90.     }

  91.     /**
  92.      * Fit a mixture model to the data supplied to the constructor.
  93.      *
  94.      * The quality of the fit depends on the concavity of the data provided to
  95.      * the constructor and the initial mixture provided to this function. If the
  96.      * data has many local optima, multiple runs of the fitting function with
  97.      * different initial mixtures may be required to find the optimal solution.
  98.      * If a MathIllegalArgumentException is encountered, it is possible that another
  99.      * initialization would work.
  100.      *
  101.      * @param initialMixture Model containing initial values of weights and
  102.      * multivariate normals
  103.      * @param maxIterations Maximum iterations allowed for fit
  104.      * @param threshold Convergence threshold computed as difference in
  105.      * logLikelihoods between successive iterations
  106.      * @throws MathIllegalArgumentException if any component's covariance matrix is
  107.      * singular during fitting
  108.      * @throws MathIllegalArgumentException if numComponents is less than one
  109.      * or threshold is less than Double.MIN_VALUE
  110.      * @throws MathIllegalArgumentException if initialMixture mean vector and data
  111.      * number of columns are not equal
  112.      */
  113.     public void fit(final MixtureMultivariateNormalDistribution initialMixture,
  114.                     final int maxIterations,
  115.                     final double threshold)
  116.         throws MathIllegalArgumentException {
  117.         if (maxIterations < 1) {
  118.             throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL,
  119.                                                    maxIterations, 1);
  120.         }

  121.         if (threshold < Double.MIN_VALUE) {
  122.             throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL,
  123.                                                    threshold, Double.MIN_VALUE);
  124.         }

  125.         final int n = data.length;

  126.         // Number of data columns. Jagged data already rejected in constructor,
  127.         // so we can assume the lengths of each row are equal.
  128.         final int numCols = data[0].length;
  129.         final int k = initialMixture.getComponents().size();

  130.         final int numMeanColumns
  131.             = initialMixture.getComponents().get(0).getSecond().getMeans().length;

  132.         if (numMeanColumns != numCols) {
  133.             throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
  134.                                                    numMeanColumns, numCols);
  135.         }

  136.         double previousLogLikelihood = 0d;

  137.         logLikelihood = Double.NEGATIVE_INFINITY;

  138.         // Initialize model to fit to initial mixture.
  139.         fittedModel = new MixtureMultivariateNormalDistribution(initialMixture.getComponents());

  140.         for (int numIterations = 0;
  141.              numIterations < maxIterations && FastMath.abs(previousLogLikelihood - logLikelihood) > threshold;
  142.              ++numIterations) {
  143.             previousLogLikelihood = logLikelihood;
  144.             double sumLogLikelihood = 0d;

  145.             // Mixture components
  146.             final List<Pair<Double, MultivariateNormalDistribution>> components
  147.                 = fittedModel.getComponents();

  148.             // Weight and distribution of each component
  149.             final double[] weights = new double[k];

  150.             final MultivariateNormalDistribution[] mvns = new MultivariateNormalDistribution[k];

  151.             for (int j = 0; j < k; j++) {
  152.                 weights[j] = components.get(j).getFirst();
  153.                 mvns[j] = components.get(j).getSecond();
  154.             }

  155.             // E-step: compute the data dependent parameters of the expectation
  156.             // function.
  157.             // The percentage of row's total density between a row and a
  158.             // component
  159.             final double[][] gamma = new double[n][k];

  160.             // Sum of gamma for each component
  161.             final double[] gammaSums = new double[k];

  162.             // Sum of gamma times its row for each each component
  163.             final double[][] gammaDataProdSums = new double[k][numCols];

  164.             for (int i = 0; i < n; i++) {
  165.                 final double rowDensity = fittedModel.density(data[i]);
  166.                 sumLogLikelihood += FastMath.log(rowDensity);

  167.                 for (int j = 0; j < k; j++) {
  168.                     gamma[i][j] = weights[j] * mvns[j].density(data[i]) / rowDensity;
  169.                     gammaSums[j] += gamma[i][j];

  170.                     for (int col = 0; col < numCols; col++) {
  171.                         gammaDataProdSums[j][col] += gamma[i][j] * data[i][col];
  172.                     }
  173.                 }
  174.             }

  175.             logLikelihood = sumLogLikelihood / n;

  176.             // M-step: compute the new parameters based on the expectation
  177.             // function.
  178.             final double[] newWeights = new double[k];
  179.             final double[][] newMeans = new double[k][numCols];

  180.             for (int j = 0; j < k; j++) {
  181.                 newWeights[j] = gammaSums[j] / n;
  182.                 for (int col = 0; col < numCols; col++) {
  183.                     newMeans[j][col] = gammaDataProdSums[j][col] / gammaSums[j];
  184.                 }
  185.             }

  186.             // Compute new covariance matrices
  187.             final RealMatrix[] newCovMats = new RealMatrix[k];
  188.             for (int j = 0; j < k; j++) {
  189.                 newCovMats[j] = new Array2DRowRealMatrix(numCols, numCols);
  190.             }
  191.             for (int i = 0; i < n; i++) {
  192.                 for (int j = 0; j < k; j++) {
  193.                     final RealMatrix vec
  194.                         = new Array2DRowRealMatrix(MathArrays.ebeSubtract(data[i], newMeans[j]));
  195.                     final RealMatrix dataCov
  196.                         = vec.multiplyTransposed(vec).scalarMultiply(gamma[i][j]);
  197.                     newCovMats[j] = newCovMats[j].add(dataCov);
  198.                 }
  199.             }

  200.             // Converting to arrays for use by fitted model
  201.             final double[][][] newCovMatArrays = new double[k][numCols][numCols];
  202.             for (int j = 0; j < k; j++) {
  203.                 newCovMats[j] = newCovMats[j].scalarMultiply(1d / gammaSums[j]);
  204.                 newCovMatArrays[j] = newCovMats[j].getData();
  205.             }

  206.             // Update current model
  207.             fittedModel = new MixtureMultivariateNormalDistribution(newWeights,
  208.                                                                     newMeans,
  209.                                                                     newCovMatArrays);
  210.         }

  211.         if (FastMath.abs(previousLogLikelihood - logLikelihood) > threshold) {
  212.             // Did not converge before the maximum number of iterations
  213.             throw new MathIllegalStateException(LocalizedCoreFormats.CONVERGENCE_FAILED);
  214.         }
  215.     }

  216.     /**
  217.      * Fit a mixture model to the data supplied to the constructor.
  218.      *
  219.      * The quality of the fit depends on the concavity of the data provided to
  220.      * the constructor and the initial mixture provided to this function. If the
  221.      * data has many local optima, multiple runs of the fitting function with
  222.      * different initial mixtures may be required to find the optimal solution.
  223.      * If a MathIllegalArgumentException is encountered, it is possible that another
  224.      * initialization would work.
  225.      *
  226.      * @param initialMixture Model containing initial values of weights and
  227.      * multivariate normals
  228.      * @throws MathIllegalArgumentException if any component's covariance matrix is
  229.      * singular during fitting
  230.      * @throws MathIllegalArgumentException if numComponents is less than one or
  231.      * threshold is less than Double.MIN_VALUE
  232.      */
  233.     public void fit(MixtureMultivariateNormalDistribution initialMixture)
  234.         throws MathIllegalArgumentException {
  235.         fit(initialMixture, DEFAULT_MAX_ITERATIONS, DEFAULT_THRESHOLD);
  236.     }

  237.     /**
  238.      * Helper method to create a multivariate normal mixture model which can be
  239.      * used to initialize {@link #fit(MixtureMultivariateNormalDistribution)}.
  240.      *
  241.      * This method uses the data supplied to the constructor to try to determine
  242.      * a good mixture model at which to start the fit, but it is not guaranteed
  243.      * to supply a model which will find the optimal solution or even converge.
  244.      *
  245.      * @param data Data to estimate distribution
  246.      * @param numComponents Number of components for estimated mixture
  247.      * @return Multivariate normal mixture model estimated from the data
  248.      * @throws MathIllegalArgumentException if {@code numComponents} is greater
  249.      * than the number of data rows.
  250.      * @throws MathIllegalArgumentException if {@code numComponents < 2}.
  251.      * @throws MathIllegalArgumentException if data has less than 2 rows
  252.      * @throws MathIllegalArgumentException if rows of data have different numbers
  253.      * of columns
  254.      */
  255.     public static MixtureMultivariateNormalDistribution estimate(final double[][] data,
  256.                                                                  final int numComponents)
  257.         throws MathIllegalArgumentException {
  258.         if (data.length < 2) {
  259.             throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL,
  260.                                                    data.length, 2);
  261.         }
  262.         if (numComponents < 2) {
  263.             throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL,
  264.                                                    numComponents, 2);
  265.         }
  266.         if (numComponents > data.length) {
  267.             throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_LARGE,
  268.                                                    numComponents, data.length);
  269.         }

  270.         final int numRows = data.length;
  271.         final int numCols = data[0].length;

  272.         // sort the data
  273.         final DataRow[] sortedData = new DataRow[numRows];
  274.         for (int i = 0; i < numRows; i++) {
  275.             sortedData[i] = new DataRow(data[i]);
  276.         }
  277.         Arrays.sort(sortedData);

  278.         // uniform weight for each bin
  279.         final double weight = 1d / numComponents;

  280.         // components of mixture model to be created
  281.         final List<Pair<Double, MultivariateNormalDistribution>> components = new ArrayList<>(numComponents);

  282.         // create a component based on data in each bin
  283.         for (int binIndex = 0; binIndex < numComponents; binIndex++) {
  284.             // minimum index (inclusive) from sorted data for this bin
  285.             final int minIndex = (binIndex * numRows) / numComponents;

  286.             // maximum index (exclusive) from sorted data for this bin
  287.             final int maxIndex = ((binIndex + 1) * numRows) / numComponents;

  288.             // number of data records that will be in this bin
  289.             final int numBinRows = maxIndex - minIndex;

  290.             // data for this bin
  291.             final double[][] binData = new double[numBinRows][numCols];

  292.             // mean of each column for the data in the this bin
  293.             final double[] columnMeans = new double[numCols];

  294.             // populate bin and create component
  295.             for (int i = minIndex; i < maxIndex; i++) {
  296.                 final int iBin = i - minIndex;
  297.                 for (int j = 0; j < numCols; j++) {
  298.                     final double val = sortedData[i].getRow()[j];
  299.                     columnMeans[j] += val;
  300.                     binData[iBin][j] = val;
  301.                 }
  302.             }

  303.             MathArrays.scaleInPlace(1d / numBinRows, columnMeans);

  304.             // covariance matrix for this bin
  305.             final double[][] covMat
  306.                 = new Covariance(binData).getCovarianceMatrix().getData();
  307.             final MultivariateNormalDistribution mvn
  308.                 = new MultivariateNormalDistribution(columnMeans, covMat);

  309.             components.add(new Pair<Double, MultivariateNormalDistribution>(weight, mvn));
  310.         }

  311.         return new MixtureMultivariateNormalDistribution(components);
  312.     }

  313.     /**
  314.      * Gets the log likelihood of the data under the fitted model.
  315.      *
  316.      * @return Log likelihood of data or zero of no data has been fit
  317.      */
  318.     public double getLogLikelihood() {
  319.         return logLikelihood;
  320.     }

  321.     /**
  322.      * Gets the fitted model.
  323.      *
  324.      * @return fitted model or {@code null} if no fit has been performed yet.
  325.      */
  326.     public MixtureMultivariateNormalDistribution getFittedModel() {
  327.         return new MixtureMultivariateNormalDistribution(fittedModel.getComponents());
  328.     }

  329.     /**
  330.      * Class used for sorting user-supplied data.
  331.      */
  332.     private static class DataRow implements Comparable<DataRow> {
  333.         /** One data row. */
  334.         private final double[] row;
  335.         /** Mean of the data row. */
  336.         private Double mean;

  337.         /**
  338.          * Create a data row.
  339.          * @param data Data to use for the row, a reference to the data is stored
  340.          */
  341.         DataRow(final double[] data) {
  342.             // Store reference.
  343.             row = data; // NOPMD - storing a reference to the array is intentional and documented here
  344.             // Compute mean.
  345.             mean = 0d;
  346.             for (int i = 0; i < data.length; i++) {
  347.                 mean += data[i];
  348.             }
  349.             mean /= data.length;
  350.         }

  351.         /**
  352.          * Compare two data rows.
  353.          * @param other The other row
  354.          * @return int for sorting
  355.          */
  356.         @Override
  357.         public int compareTo(final DataRow other) {
  358.             return mean.compareTo(other.mean);
  359.         }

  360.         /** {@inheritDoc} */
  361.         @Override
  362.         public boolean equals(Object other) {

  363.             if (this == other) {
  364.                 return true;
  365.             }

  366.             if (other instanceof DataRow) {
  367.                 return MathArrays.equals(row, ((DataRow) other).row);
  368.             }

  369.             return false;

  370.         }

  371.         /** {@inheritDoc} */
  372.         @Override
  373.         public int hashCode() {
  374.             return Arrays.hashCode(row);
  375.         }
  376.         /**
  377.          * Get a data row.
  378.          * @return data row array (a reference to the stored array is returned)
  379.          */
  380.         public double[] getRow() {
  381.             return row; // NOPMD - returning a reference to an internal array is documented here
  382.         }
  383.     }
  384. }