View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  package org.hipparchus.stat.fitting;
23  
24  import java.util.ArrayList;
25  import java.util.Arrays;
26  import java.util.List;
27  
28  import org.hipparchus.distribution.multivariate.MixtureMultivariateNormalDistribution;
29  import org.hipparchus.distribution.multivariate.MultivariateNormalDistribution;
30  import org.hipparchus.exception.LocalizedCoreFormats;
31  import org.hipparchus.exception.MathIllegalArgumentException;
32  import org.hipparchus.exception.MathIllegalStateException;
33  import org.hipparchus.linear.Array2DRowRealMatrix;
34  import org.hipparchus.linear.RealMatrix;
35  import org.hipparchus.stat.correlation.Covariance;
36  import org.hipparchus.util.FastMath;
37  import org.hipparchus.util.MathArrays;
38  import org.hipparchus.util.Pair;
39  
40  /**
41   * Expectation-Maximization algorithm for fitting the parameters of
42   * multivariate normal mixture model distributions.
43   *
44   * This implementation is pure original code based on <a
45   * href="https://www.ee.washington.edu/techsite/papers/documents/UWEETR-2010-0002.pdf">
46   * EM Demystified: An Expectation-Maximization Tutorial</a> by Yihua Chen and Maya R. Gupta,
47   * Department of Electrical Engineering, University of Washington, Seattle, WA 98195.
48   * It was verified using external tools like <a
49   * href="http://cran.r-project.org/web/packages/mixtools/index.html">CRAN Mixtools</a>
50   * (see the JUnit test cases) but it is <strong>not</strong> based on Mixtools code at all.
51   * The discussion of the origin of this class can be seen in the comments of the <a
52   * href="https://issues.apache.org/jira/browse/MATH-817">MATH-817</a> JIRA issue.
53   */
54  public class MultivariateNormalMixtureExpectationMaximization {
55      /** Default maximum number of iterations allowed per fitting process. */
56      private static final int DEFAULT_MAX_ITERATIONS = 1000;
57      /** Default convergence threshold for fitting. */
58      private static final double DEFAULT_THRESHOLD = 1E-5;
59      /** The data to fit. */
60      private final double[][] data;
61      /** The model fit against the data. */
62      private MixtureMultivariateNormalDistribution fittedModel;
63      /** The log likelihood of the data given the fitted model. */
64      private double logLikelihood;
65  
66      /**
67       * Creates an object to fit a multivariate normal mixture model to data.
68       *
69       * @param data Data to use in fitting procedure
70       * @throws MathIllegalArgumentException if data has no rows
71       * @throws MathIllegalArgumentException if rows of data have different numbers
72       * of columns
73       * @throws MathIllegalArgumentException if the number of columns in the data is
74       * less than 2
75       */
76      public MultivariateNormalMixtureExpectationMaximization(double[][] data)
77          throws MathIllegalArgumentException {
78          if (data.length < 1) {
79              throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL,
80                                                     data.length, 1);
81          }
82  
83          this.data = new double[data.length][data[0].length];
84  
85          for (int i = 0; i < data.length; i++) {
86              if (data[i].length != data[0].length) {
87                  // Jagged arrays not allowed
88                  throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
89                                                         data[i].length, data[0].length);
90              }
91              if (data[i].length < 2) {
92                  throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL,
93                                                      data[i].length, 2, true);
94              }
95              this.data[i] = data[i].clone();
96          }
97      }
98  
99      /**
100      * Fit a mixture model to the data supplied to the constructor.
101      *
102      * The quality of the fit depends on the concavity of the data provided to
103      * the constructor and the initial mixture provided to this function. If the
104      * data has many local optima, multiple runs of the fitting function with
105      * different initial mixtures may be required to find the optimal solution.
106      * If a MathIllegalArgumentException is encountered, it is possible that another
107      * initialization would work.
108      *
109      * @param initialMixture Model containing initial values of weights and
110      * multivariate normals
111      * @param maxIterations Maximum iterations allowed for fit
112      * @param threshold Convergence threshold computed as difference in
113      * logLikelihoods between successive iterations
114      * @throws MathIllegalArgumentException if any component's covariance matrix is
115      * singular during fitting
116      * @throws MathIllegalArgumentException if numComponents is less than one
117      * or threshold is less than Double.MIN_VALUE
118      * @throws MathIllegalArgumentException if initialMixture mean vector and data
119      * number of columns are not equal
120      */
121     public void fit(final MixtureMultivariateNormalDistribution initialMixture,
122                     final int maxIterations,
123                     final double threshold)
124         throws MathIllegalArgumentException {
125         if (maxIterations < 1) {
126             throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL,
127                                                    maxIterations, 1);
128         }
129 
130         if (threshold < Double.MIN_VALUE) {
131             throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL,
132                                                    threshold, Double.MIN_VALUE);
133         }
134 
135         final int n = data.length;
136 
137         // Number of data columns. Jagged data already rejected in constructor,
138         // so we can assume the lengths of each row are equal.
139         final int numCols = data[0].length;
140         final int k = initialMixture.getComponents().size();
141 
142         final int numMeanColumns
143             = initialMixture.getComponents().get(0).getSecond().getMeans().length;
144 
145         if (numMeanColumns != numCols) {
146             throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
147                                                    numMeanColumns, numCols);
148         }
149 
150         double previousLogLikelihood = 0d;
151 
152         logLikelihood = Double.NEGATIVE_INFINITY;
153 
154         // Initialize model to fit to initial mixture.
155         fittedModel = new MixtureMultivariateNormalDistribution(initialMixture.getComponents());
156 
157         for (int numIterations = 0;
158              numIterations < maxIterations && FastMath.abs(previousLogLikelihood - logLikelihood) > threshold;
159              ++numIterations) {
160             previousLogLikelihood = logLikelihood;
161             double sumLogLikelihood = 0d;
162 
163             // Mixture components
164             final List<Pair<Double, MultivariateNormalDistribution>> components
165                 = fittedModel.getComponents();
166 
167             // Weight and distribution of each component
168             final double[] weights = new double[k];
169 
170             final MultivariateNormalDistribution[] mvns = new MultivariateNormalDistribution[k];
171 
172             for (int j = 0; j < k; j++) {
173                 weights[j] = components.get(j).getFirst();
174                 mvns[j] = components.get(j).getSecond();
175             }
176 
177             // E-step: compute the data dependent parameters of the expectation
178             // function.
179             // The percentage of row's total density between a row and a
180             // component
181             final double[][] gamma = new double[n][k];
182 
183             // Sum of gamma for each component
184             final double[] gammaSums = new double[k];
185 
186             // Sum of gamma times its row for each each component
187             final double[][] gammaDataProdSums = new double[k][numCols];
188 
189             for (int i = 0; i < n; i++) {
190                 final double rowDensity = fittedModel.density(data[i]);
191                 sumLogLikelihood += FastMath.log(rowDensity);
192 
193                 for (int j = 0; j < k; j++) {
194                     gamma[i][j] = weights[j] * mvns[j].density(data[i]) / rowDensity;
195                     gammaSums[j] += gamma[i][j];
196 
197                     for (int col = 0; col < numCols; col++) {
198                         gammaDataProdSums[j][col] += gamma[i][j] * data[i][col];
199                     }
200                 }
201             }
202 
203             logLikelihood = sumLogLikelihood / n;
204 
205             // M-step: compute the new parameters based on the expectation
206             // function.
207             final double[] newWeights = new double[k];
208             final double[][] newMeans = new double[k][numCols];
209 
210             for (int j = 0; j < k; j++) {
211                 newWeights[j] = gammaSums[j] / n;
212                 for (int col = 0; col < numCols; col++) {
213                     newMeans[j][col] = gammaDataProdSums[j][col] / gammaSums[j];
214                 }
215             }
216 
217             // Compute new covariance matrices
218             final RealMatrix[] newCovMats = new RealMatrix[k];
219             for (int j = 0; j < k; j++) {
220                 newCovMats[j] = new Array2DRowRealMatrix(numCols, numCols);
221             }
222             for (int i = 0; i < n; i++) {
223                 for (int j = 0; j < k; j++) {
224                     final RealMatrix vec
225                         = new Array2DRowRealMatrix(MathArrays.ebeSubtract(data[i], newMeans[j]));
226                     final RealMatrix dataCov
227                         = vec.multiplyTransposed(vec).scalarMultiply(gamma[i][j]);
228                     newCovMats[j] = newCovMats[j].add(dataCov);
229                 }
230             }
231 
232             // Converting to arrays for use by fitted model
233             final double[][][] newCovMatArrays = new double[k][numCols][numCols];
234             for (int j = 0; j < k; j++) {
235                 newCovMats[j] = newCovMats[j].scalarMultiply(1d / gammaSums[j]);
236                 newCovMatArrays[j] = newCovMats[j].getData();
237             }
238 
239             // Update current model
240             fittedModel = new MixtureMultivariateNormalDistribution(newWeights,
241                                                                     newMeans,
242                                                                     newCovMatArrays);
243         }
244 
245         if (FastMath.abs(previousLogLikelihood - logLikelihood) > threshold) {
246             // Did not converge before the maximum number of iterations
247             throw new MathIllegalStateException(LocalizedCoreFormats.CONVERGENCE_FAILED);
248         }
249     }
250 
251     /**
252      * Fit a mixture model to the data supplied to the constructor.
253      *
254      * The quality of the fit depends on the concavity of the data provided to
255      * the constructor and the initial mixture provided to this function. If the
256      * data has many local optima, multiple runs of the fitting function with
257      * different initial mixtures may be required to find the optimal solution.
258      * If a MathIllegalArgumentException is encountered, it is possible that another
259      * initialization would work.
260      *
261      * @param initialMixture Model containing initial values of weights and
262      * multivariate normals
263      * @throws MathIllegalArgumentException if any component's covariance matrix is
264      * singular during fitting
265      * @throws MathIllegalArgumentException if numComponents is less than one or
266      * threshold is less than Double.MIN_VALUE
267      */
268     public void fit(MixtureMultivariateNormalDistribution initialMixture)
269         throws MathIllegalArgumentException {
270         fit(initialMixture, DEFAULT_MAX_ITERATIONS, DEFAULT_THRESHOLD);
271     }
272 
273     /**
274      * Helper method to create a multivariate normal mixture model which can be
275      * used to initialize {@link #fit(MixtureMultivariateNormalDistribution)}.
276      *
277      * This method uses the data supplied to the constructor to try to determine
278      * a good mixture model at which to start the fit, but it is not guaranteed
279      * to supply a model which will find the optimal solution or even converge.
280      *
281      * @param data Data to estimate distribution
282      * @param numComponents Number of components for estimated mixture
283      * @return Multivariate normal mixture model estimated from the data
284      * @throws MathIllegalArgumentException if {@code numComponents} is greater
285      * than the number of data rows.
286      * @throws MathIllegalArgumentException if {@code numComponents < 2}.
287      * @throws MathIllegalArgumentException if data has less than 2 rows
288      * @throws MathIllegalArgumentException if rows of data have different numbers
289      * of columns
290      */
291     public static MixtureMultivariateNormalDistribution estimate(final double[][] data,
292                                                                  final int numComponents)
293         throws MathIllegalArgumentException {
294         if (data.length < 2) {
295             throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL,
296                                                    data.length, 2);
297         }
298         if (numComponents < 2) {
299             throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL,
300                                                    numComponents, 2);
301         }
302         if (numComponents > data.length) {
303             throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_LARGE,
304                                                    numComponents, data.length);
305         }
306 
307         final int numRows = data.length;
308         final int numCols = data[0].length;
309 
310         // sort the data
311         final DataRow[] sortedData = new DataRow[numRows];
312         for (int i = 0; i < numRows; i++) {
313             sortedData[i] = new DataRow(data[i]);
314         }
315         Arrays.sort(sortedData);
316 
317         // uniform weight for each bin
318         final double weight = 1d / numComponents;
319 
320         // components of mixture model to be created
321         final List<Pair<Double, MultivariateNormalDistribution>> components = new ArrayList<>(numComponents);
322 
323         // create a component based on data in each bin
324         for (int binIndex = 0; binIndex < numComponents; binIndex++) {
325             // minimum index (inclusive) from sorted data for this bin
326             final int minIndex = (binIndex * numRows) / numComponents;
327 
328             // maximum index (exclusive) from sorted data for this bin
329             final int maxIndex = ((binIndex + 1) * numRows) / numComponents;
330 
331             // number of data records that will be in this bin
332             final int numBinRows = maxIndex - minIndex;
333 
334             // data for this bin
335             final double[][] binData = new double[numBinRows][numCols];
336 
337             // mean of each column for the data in the this bin
338             final double[] columnMeans = new double[numCols];
339 
340             // populate bin and create component
341             for (int i = minIndex; i < maxIndex; i++) {
342                 final int iBin = i - minIndex;
343                 for (int j = 0; j < numCols; j++) {
344                     final double val = sortedData[i].getRow()[j];
345                     columnMeans[j] += val;
346                     binData[iBin][j] = val;
347                 }
348             }
349 
350             MathArrays.scaleInPlace(1d / numBinRows, columnMeans);
351 
352             // covariance matrix for this bin
353             final double[][] covMat
354                 = new Covariance(binData).getCovarianceMatrix().getData();
355             final MultivariateNormalDistribution mvn
356                 = new MultivariateNormalDistribution(columnMeans, covMat);
357 
358             components.add(new Pair<Double, MultivariateNormalDistribution>(weight, mvn));
359         }
360 
361         return new MixtureMultivariateNormalDistribution(components);
362     }
363 
364     /**
365      * Gets the log likelihood of the data under the fitted model.
366      *
367      * @return Log likelihood of data or zero of no data has been fit
368      */
369     public double getLogLikelihood() {
370         return logLikelihood;
371     }
372 
373     /**
374      * Gets the fitted model.
375      *
376      * @return fitted model or {@code null} if no fit has been performed yet.
377      */
378     public MixtureMultivariateNormalDistribution getFittedModel() {
379         return new MixtureMultivariateNormalDistribution(fittedModel.getComponents());
380     }
381 
382     /**
383      * Class used for sorting user-supplied data.
384      */
385     private static class DataRow implements Comparable<DataRow> {
386         /** One data row. */
387         private final double[] row;
388         /** Mean of the data row. */
389         private Double mean;
390 
391         /**
392          * Create a data row.
393          * @param data Data to use for the row, a reference to the data is stored
394          */
395         DataRow(final double[] data) {
396             // Store reference.
397             row = data; // NOPMD - storing a reference to the array is intentional and documented here
398             // Compute mean.
399             mean = 0d;
400             for (int i = 0; i < data.length; i++) {
401                 mean += data[i];
402             }
403             mean /= data.length;
404         }
405 
406         /**
407          * Compare two data rows.
408          * @param other The other row
409          * @return int for sorting
410          */
411         @Override
412         public int compareTo(final DataRow other) {
413             return mean.compareTo(other.mean);
414         }
415 
416         /** {@inheritDoc} */
417         @Override
418         public boolean equals(Object other) {
419 
420             if (this == other) {
421                 return true;
422             }
423 
424             if (other instanceof DataRow) {
425                 return MathArrays.equals(row, ((DataRow) other).row);
426             }
427 
428             return false;
429 
430         }
431 
432         /** {@inheritDoc} */
433         @Override
434         public int hashCode() {
435             return Arrays.hashCode(row);
436         }
437         /**
438          * Get a data row.
439          * @return data row array (a reference to the stored array is returned)
440          */
441         public double[] getRow() {
442             return row; // NOPMD - returning a reference to an internal array is documented here
443         }
444     }
445 }
446