1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 package org.hipparchus.stat.fitting;
23
24 import java.util.ArrayList;
25 import java.util.Arrays;
26 import java.util.List;
27
28 import org.hipparchus.distribution.multivariate.MixtureMultivariateNormalDistribution;
29 import org.hipparchus.distribution.multivariate.MultivariateNormalDistribution;
30 import org.hipparchus.exception.LocalizedCoreFormats;
31 import org.hipparchus.exception.MathIllegalArgumentException;
32 import org.hipparchus.exception.MathIllegalStateException;
33 import org.hipparchus.linear.Array2DRowRealMatrix;
34 import org.hipparchus.linear.RealMatrix;
35 import org.hipparchus.stat.correlation.Covariance;
36 import org.hipparchus.util.FastMath;
37 import org.hipparchus.util.MathArrays;
38 import org.hipparchus.util.Pair;
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54 public class MultivariateNormalMixtureExpectationMaximization {
55
56 private static final int DEFAULT_MAX_ITERATIONS = 1000;
57
58 private static final double DEFAULT_THRESHOLD = 1E-5;
59
60 private final double[][] data;
61
62 private MixtureMultivariateNormalDistribution fittedModel;
63
64 private double logLikelihood;
65
66
67
68
69
70
71
72
73
74
75
76 public MultivariateNormalMixtureExpectationMaximization(double[][] data)
77 throws MathIllegalArgumentException {
78 if (data.length < 1) {
79 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL,
80 data.length, 1);
81 }
82
83 this.data = new double[data.length][data[0].length];
84
85 for (int i = 0; i < data.length; i++) {
86 if (data[i].length != data[0].length) {
87
88 throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
89 data[i].length, data[0].length);
90 }
91 if (data[i].length < 2) {
92 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL,
93 data[i].length, 2, true);
94 }
95 this.data[i] = data[i].clone();
96 }
97 }
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121 public void fit(final MixtureMultivariateNormalDistribution initialMixture,
122 final int maxIterations,
123 final double threshold)
124 throws MathIllegalArgumentException {
125 if (maxIterations < 1) {
126 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL,
127 maxIterations, 1);
128 }
129
130 if (threshold < Double.MIN_VALUE) {
131 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL,
132 threshold, Double.MIN_VALUE);
133 }
134
135 final int n = data.length;
136
137
138
139 final int numCols = data[0].length;
140 final int k = initialMixture.getComponents().size();
141
142 final int numMeanColumns
143 = initialMixture.getComponents().get(0).getSecond().getMeans().length;
144
145 if (numMeanColumns != numCols) {
146 throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
147 numMeanColumns, numCols);
148 }
149
150 double previousLogLikelihood = 0d;
151
152 logLikelihood = Double.NEGATIVE_INFINITY;
153
154
155 fittedModel = new MixtureMultivariateNormalDistribution(initialMixture.getComponents());
156
157 for (int numIterations = 0;
158 numIterations < maxIterations && FastMath.abs(previousLogLikelihood - logLikelihood) > threshold;
159 ++numIterations) {
160 previousLogLikelihood = logLikelihood;
161 double sumLogLikelihood = 0d;
162
163
164 final List<Pair<Double, MultivariateNormalDistribution>> components
165 = fittedModel.getComponents();
166
167
168 final double[] weights = new double[k];
169
170 final MultivariateNormalDistribution[] mvns = new MultivariateNormalDistribution[k];
171
172 for (int j = 0; j < k; j++) {
173 weights[j] = components.get(j).getFirst();
174 mvns[j] = components.get(j).getSecond();
175 }
176
177
178
179
180
181 final double[][] gamma = new double[n][k];
182
183
184 final double[] gammaSums = new double[k];
185
186
187 final double[][] gammaDataProdSums = new double[k][numCols];
188
189 for (int i = 0; i < n; i++) {
190 final double rowDensity = fittedModel.density(data[i]);
191 sumLogLikelihood += FastMath.log(rowDensity);
192
193 for (int j = 0; j < k; j++) {
194 gamma[i][j] = weights[j] * mvns[j].density(data[i]) / rowDensity;
195 gammaSums[j] += gamma[i][j];
196
197 for (int col = 0; col < numCols; col++) {
198 gammaDataProdSums[j][col] += gamma[i][j] * data[i][col];
199 }
200 }
201 }
202
203 logLikelihood = sumLogLikelihood / n;
204
205
206
207 final double[] newWeights = new double[k];
208 final double[][] newMeans = new double[k][numCols];
209
210 for (int j = 0; j < k; j++) {
211 newWeights[j] = gammaSums[j] / n;
212 for (int col = 0; col < numCols; col++) {
213 newMeans[j][col] = gammaDataProdSums[j][col] / gammaSums[j];
214 }
215 }
216
217
218 final RealMatrix[] newCovMats = new RealMatrix[k];
219 for (int j = 0; j < k; j++) {
220 newCovMats[j] = new Array2DRowRealMatrix(numCols, numCols);
221 }
222 for (int i = 0; i < n; i++) {
223 for (int j = 0; j < k; j++) {
224 final RealMatrix vec
225 = new Array2DRowRealMatrix(MathArrays.ebeSubtract(data[i], newMeans[j]));
226 final RealMatrix dataCov
227 = vec.multiplyTransposed(vec).scalarMultiply(gamma[i][j]);
228 newCovMats[j] = newCovMats[j].add(dataCov);
229 }
230 }
231
232
233 final double[][][] newCovMatArrays = new double[k][numCols][numCols];
234 for (int j = 0; j < k; j++) {
235 newCovMats[j] = newCovMats[j].scalarMultiply(1d / gammaSums[j]);
236 newCovMatArrays[j] = newCovMats[j].getData();
237 }
238
239
240 fittedModel = new MixtureMultivariateNormalDistribution(newWeights,
241 newMeans,
242 newCovMatArrays);
243 }
244
245 if (FastMath.abs(previousLogLikelihood - logLikelihood) > threshold) {
246
247 throw new MathIllegalStateException(LocalizedCoreFormats.CONVERGENCE_FAILED);
248 }
249 }
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268 public void fit(MixtureMultivariateNormalDistribution initialMixture)
269 throws MathIllegalArgumentException {
270 fit(initialMixture, DEFAULT_MAX_ITERATIONS, DEFAULT_THRESHOLD);
271 }
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291 public static MixtureMultivariateNormalDistribution estimate(final double[][] data,
292 final int numComponents)
293 throws MathIllegalArgumentException {
294 if (data.length < 2) {
295 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL,
296 data.length, 2);
297 }
298 if (numComponents < 2) {
299 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL,
300 numComponents, 2);
301 }
302 if (numComponents > data.length) {
303 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_LARGE,
304 numComponents, data.length);
305 }
306
307 final int numRows = data.length;
308 final int numCols = data[0].length;
309
310
311 final DataRow[] sortedData = new DataRow[numRows];
312 for (int i = 0; i < numRows; i++) {
313 sortedData[i] = new DataRow(data[i]);
314 }
315 Arrays.sort(sortedData);
316
317
318 final double weight = 1d / numComponents;
319
320
321 final List<Pair<Double, MultivariateNormalDistribution>> components = new ArrayList<>(numComponents);
322
323
324 for (int binIndex = 0; binIndex < numComponents; binIndex++) {
325
326 final int minIndex = (binIndex * numRows) / numComponents;
327
328
329 final int maxIndex = ((binIndex + 1) * numRows) / numComponents;
330
331
332 final int numBinRows = maxIndex - minIndex;
333
334
335 final double[][] binData = new double[numBinRows][numCols];
336
337
338 final double[] columnMeans = new double[numCols];
339
340
341 for (int i = minIndex; i < maxIndex; i++) {
342 final int iBin = i - minIndex;
343 for (int j = 0; j < numCols; j++) {
344 final double val = sortedData[i].getRow()[j];
345 columnMeans[j] += val;
346 binData[iBin][j] = val;
347 }
348 }
349
350 MathArrays.scaleInPlace(1d / numBinRows, columnMeans);
351
352
353 final double[][] covMat
354 = new Covariance(binData).getCovarianceMatrix().getData();
355 final MultivariateNormalDistribution mvn
356 = new MultivariateNormalDistribution(columnMeans, covMat);
357
358 components.add(new Pair<Double, MultivariateNormalDistribution>(weight, mvn));
359 }
360
361 return new MixtureMultivariateNormalDistribution(components);
362 }
363
364
365
366
367
368
369 public double getLogLikelihood() {
370 return logLikelihood;
371 }
372
373
374
375
376
377
378 public MixtureMultivariateNormalDistribution getFittedModel() {
379 return new MixtureMultivariateNormalDistribution(fittedModel.getComponents());
380 }
381
382
383
384
385 private static class DataRow implements Comparable<DataRow> {
386
387 private final double[] row;
388
389 private Double mean;
390
391
392
393
394
395 DataRow(final double[] data) {
396
397 row = data;
398
399 mean = 0d;
400 for (int i = 0; i < data.length; i++) {
401 mean += data[i];
402 }
403 mean /= data.length;
404 }
405
406
407
408
409
410
411 @Override
412 public int compareTo(final DataRow other) {
413 return mean.compareTo(other.mean);
414 }
415
416
417 @Override
418 public boolean equals(Object other) {
419
420 if (this == other) {
421 return true;
422 }
423
424 if (other instanceof DataRow) {
425 return MathArrays.equals(row, ((DataRow) other).row);
426 }
427
428 return false;
429
430 }
431
432
433 @Override
434 public int hashCode() {
435 return Arrays.hashCode(row);
436 }
437
438
439
440
441 public double[] getRow() {
442 return row;
443 }
444 }
445 }
446