1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.hipparchus.stat.fitting;
18
19 import java.lang.reflect.Constructor;
20 import java.lang.reflect.InvocationTargetException;
21 import java.util.ArrayList;
22 import java.util.Arrays;
23 import java.util.List;
24
25 import org.hipparchus.distribution.multivariate.MixtureMultivariateNormalDistribution;
26 import org.hipparchus.distribution.multivariate.MultivariateNormalDistribution;
27 import org.hipparchus.exception.LocalizedCoreFormats;
28 import org.hipparchus.exception.MathIllegalArgumentException;
29 import org.hipparchus.exception.MathIllegalStateException;
30 import org.hipparchus.linear.Array2DRowRealMatrix;
31 import org.hipparchus.linear.RealMatrix;
32 import org.hipparchus.util.Pair;
33 import org.junit.Assert;
34 import org.junit.Test;
35
36
37
38
39
40 public class MultivariateNormalMixtureExpectationMaximizationTest {
41
42 @Test(expected = MathIllegalArgumentException.class)
43 public void testNonEmptyData() {
44
45 new MultivariateNormalMixtureExpectationMaximization(new double[][] {});
46 }
47
48 @Test(expected = MathIllegalArgumentException.class)
49 public void testNonJaggedData() {
50
51 double[][] data = new double[][] {
52 { 1, 2, 3 },
53 { 4, 5, 6, 7 },
54 };
55 new MultivariateNormalMixtureExpectationMaximization(data);
56 }
57
58 @Test(expected = MathIllegalArgumentException.class)
59 public void testMultipleColumnsRequired() {
60
61 double[][] data = new double[][] {
62 { 1 }, { 2 }
63 };
64 new MultivariateNormalMixtureExpectationMaximization(data);
65 }
66
67 @Test(expected = MathIllegalArgumentException.class)
68 public void testMaxIterationsPositive() {
69
70 double[][] data = getTestSamples();
71 MultivariateNormalMixtureExpectationMaximization fitter =
72 new MultivariateNormalMixtureExpectationMaximization(data);
73
74 MixtureMultivariateNormalDistribution
75 initialMix = MultivariateNormalMixtureExpectationMaximization.estimate(data, 2);
76
77 fitter.fit(initialMix, 0, 1E-5);
78 }
79
80 @Test(expected = MathIllegalArgumentException.class)
81 public void testThresholdPositive() {
82
83 double[][] data = getTestSamples();
84 MultivariateNormalMixtureExpectationMaximization fitter =
85 new MultivariateNormalMixtureExpectationMaximization(
86 data);
87
88 MixtureMultivariateNormalDistribution
89 initialMix = MultivariateNormalMixtureExpectationMaximization.estimate(data, 2);
90
91 fitter.fit(initialMix, 1000, 0);
92 }
93
94 @Test(expected = MathIllegalStateException.class)
95 public void testConvergenceException() {
96
97 double[][] data = getTestSamples();
98 MultivariateNormalMixtureExpectationMaximization fitter
99 = new MultivariateNormalMixtureExpectationMaximization(data);
100
101 MixtureMultivariateNormalDistribution
102 initialMix = MultivariateNormalMixtureExpectationMaximization.estimate(data, 2);
103
104
105 fitter.fit(initialMix, 5, 1E-5);
106 }
107
108 @Test(expected = MathIllegalArgumentException.class)
109 public void testIncompatibleIntialMixture() {
110
111 double[][] data = new double[][] {
112 { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 }
113 };
114 double[] weights = new double[] { 0.5, 0.5 };
115
116
117
118 MultivariateNormalDistribution[] mvns = new MultivariateNormalDistribution[2];
119
120 mvns[0] = new MultivariateNormalDistribution(new double[] {
121 -0.0021722935000328823, 3.5432892936887908 },
122 new double[][] {
123 { 4.537422569229048, 3.5266152281729304 },
124 { 3.5266152281729304, 6.175448814169779 } });
125 mvns[1] = new MultivariateNormalDistribution(new double[] {
126 5.090902706507635, 8.68540656355283 }, new double[][] {
127 { 2.886778573963039, 1.5257474543463154 },
128 { 1.5257474543463154, 3.3794567673616918 } });
129
130
131 List<Pair<Double, MultivariateNormalDistribution>> components =
132 new ArrayList<Pair<Double, MultivariateNormalDistribution>>();
133 components.add(new Pair<Double, MultivariateNormalDistribution>(
134 weights[0], mvns[0]));
135 components.add(new Pair<Double, MultivariateNormalDistribution>(
136 weights[1], mvns[1]));
137
138 MixtureMultivariateNormalDistribution badInitialMix
139 = new MixtureMultivariateNormalDistribution(components);
140
141 MultivariateNormalMixtureExpectationMaximization fitter
142 = new MultivariateNormalMixtureExpectationMaximization(data);
143
144 fitter.fit(badInitialMix);
145 }
146
147 @Test
148 public void testInitialMixture() {
149
150 final double[] correctWeights = new double[] { 0.5, 0.5 };
151
152 final double[][] correctMeans = new double[][] {
153 {-0.0021722935000328823, 3.5432892936887908},
154 {5.090902706507635, 8.68540656355283},
155 };
156
157 final RealMatrix[] correctCovMats = new Array2DRowRealMatrix[2];
158
159 correctCovMats[0] = new Array2DRowRealMatrix(new double[][] {
160 { 4.537422569229048, 3.5266152281729304 },
161 { 3.5266152281729304, 6.175448814169779 } });
162
163 correctCovMats[1] = new Array2DRowRealMatrix( new double[][] {
164 { 2.886778573963039, 1.5257474543463154 },
165 { 1.5257474543463154, 3.3794567673616918 } });
166
167 final MultivariateNormalDistribution[] correctMVNs = new
168 MultivariateNormalDistribution[2];
169
170 correctMVNs[0] = new MultivariateNormalDistribution(correctMeans[0],
171 correctCovMats[0].getData());
172
173 correctMVNs[1] = new MultivariateNormalDistribution(correctMeans[1],
174 correctCovMats[1].getData());
175
176 final MixtureMultivariateNormalDistribution initialMix
177 = MultivariateNormalMixtureExpectationMaximization.estimate(getTestSamples(), 2);
178
179 int i = 0;
180 for (Pair<Double, MultivariateNormalDistribution> component : initialMix
181 .getComponents()) {
182 Assert.assertEquals(correctWeights[i], component.getFirst(),
183 Math.ulp(1d));
184
185 final double[] means = component.getValue().getMeans();
186 Assert.assertTrue(Arrays.equals(correctMeans[i], means));
187
188 final RealMatrix covMat = component.getValue().getCovariances();
189 Assert.assertEquals(correctCovMats[i], covMat);
190 i++;
191 }
192 }
193
194 @Test
195 public void testWrongData() {
196 checkWrongData(new double[1][1], 2, LocalizedCoreFormats.NUMBER_TOO_SMALL);
197 checkWrongData(new double[3][3], 1, LocalizedCoreFormats.NUMBER_TOO_SMALL);
198 checkWrongData(new double[3][3], 4, LocalizedCoreFormats.NUMBER_TOO_LARGE);
199 }
200
201 private void checkWrongData(final double[][] data, final int numComponents,
202 final LocalizedCoreFormats expected) {
203 try {
204 MultivariateNormalMixtureExpectationMaximization.estimate(data, numComponents);
205 Assert.fail("an exception should have been thrown");
206 } catch (MathIllegalArgumentException miae) {
207 Assert.assertEquals(expected, miae.getSpecifier());
208 }
209 }
210
211 @Test
212 public void testUnusedInheritedMethods() {
213
214 try {
215 Class<?> dataRowClass = MultivariateNormalMixtureExpectationMaximization.class.getDeclaredClasses()[0];
216 Constructor<?> dataRowConstructor = dataRowClass.getDeclaredConstructor(double[].class);
217 dataRowConstructor.setAccessible(true);
218 Object dr1 = dataRowConstructor.newInstance(new double[] { 1, 2, 3 });
219 Assert.assertEquals(66614367, dr1.hashCode());
220 Assert.assertTrue(dr1.equals(dr1));
221 Assert.assertFalse(dr1.equals(""));
222 Assert.assertTrue(dr1.equals(dataRowConstructor.newInstance(new double[] { 1, 2, 3 })));
223 Assert.assertFalse(dr1.equals(dataRowConstructor.newInstance(new double[] { 3, 2, 1 })));
224 } catch (InvocationTargetException | NoSuchMethodException | SecurityException |
225 InstantiationException | IllegalAccessException | IllegalArgumentException e) {
226 Assert.fail(e.getLocalizedMessage());
227 }
228 }
229
230 @Test
231 public void testFit() {
232
233
234 final double[][] data = getTestSamples();
235 final double correctLogLikelihood = -4.292431006791994;
236 final double[] correctWeights = new double[] { 0.2962324189652912, 0.7037675810347089 };
237
238 final double[][] correctMeans = new double[][]{
239 {-1.4213112715121132, 1.6924690505757753},
240 {4.213612224374709, 7.975621325853645}
241 };
242
243 final RealMatrix[] correctCovMats = new Array2DRowRealMatrix[2];
244 correctCovMats[0] = new Array2DRowRealMatrix(new double[][] {
245 { 1.739356907285747, -0.5867644251487614 },
246 { -0.5867644251487614, 1.0232932029324642 } }
247 );
248 correctCovMats[1] = new Array2DRowRealMatrix(new double[][] {
249 { 4.245384898007161, 2.5797798966382155 },
250 { 2.5797798966382155, 3.9200272522448367 } });
251
252 final MultivariateNormalDistribution[] correctMVNs = new MultivariateNormalDistribution[2];
253 correctMVNs[0] = new MultivariateNormalDistribution(correctMeans[0], correctCovMats[0].getData());
254 correctMVNs[1] = new MultivariateNormalDistribution(correctMeans[1], correctCovMats[1].getData());
255
256 MultivariateNormalMixtureExpectationMaximization fitter
257 = new MultivariateNormalMixtureExpectationMaximization(data);
258
259 MixtureMultivariateNormalDistribution initialMix
260 = MultivariateNormalMixtureExpectationMaximization.estimate(data, 2);
261 fitter.fit(initialMix);
262 MixtureMultivariateNormalDistribution fittedMix = fitter.getFittedModel();
263 List<Pair<Double, MultivariateNormalDistribution>> components = fittedMix.getComponents();
264
265 Assert.assertEquals(correctLogLikelihood,
266 fitter.getLogLikelihood(),
267 Math.ulp(1d));
268
269 int i = 0;
270 for (Pair<Double, MultivariateNormalDistribution> component : components) {
271 final double weight = component.getFirst();
272 final MultivariateNormalDistribution mvn = component.getSecond();
273 final double[] mean = mvn.getMeans();
274 final RealMatrix covMat = mvn.getCovariances();
275 Assert.assertEquals(correctWeights[i], weight, Math.ulp(1d));
276 Assert.assertTrue(Arrays.equals(correctMeans[i], mean));
277 Assert.assertEquals(correctCovMats[i], covMat);
278 i++;
279 }
280 }
281
282 private double[][] getTestSamples() {
283
284
285 return new double[][] { { 7.358553610469948, 11.31260831446758 },
286 { 7.175770420124739, 8.988812210204454 },
287 { 4.324151905768422, 6.837727899051482 },
288 { 2.157832219173036, 6.317444585521968 },
289 { -1.890157421896651, 1.74271202875498 },
290 { 0.8922409354455803, 1.999119343923781 },
291 { 3.396949764787055, 6.813170372579068 },
292 { -2.057498232686068, -0.002522983830852255 },
293 { 6.359932157365045, 8.343600029975851 },
294 { 3.353102234276168, 7.087541882898689 },
295 { -1.763877221595639, 0.9688890460330644 },
296 { 6.151457185125111, 9.075011757431174 },
297 { 4.281597398048899, 5.953270070976117 },
298 { 3.549576703974894, 8.616038155992861 },
299 { 6.004706732349854, 8.959423391087469 },
300 { 2.802915014676262, 6.285676742173564 },
301 { -0.6029879029880616, 1.083332958357485 },
302 { 3.631827105398369, 6.743428504049444 },
303 { 6.161125014007315, 9.60920569689001 },
304 { -1.049582894255342, 0.2020017892080281 },
305 { 3.910573022688315, 8.19609909534937 },
306 { 8.180454017634863, 7.861055769719962 },
307 { 1.488945440439716, 8.02699903761247 },
308 { 4.813750847823778, 12.34416881332515 },
309 { 0.0443208501259158, 5.901148093240691 },
310 { 4.416417235068346, 4.465243084006094 },
311 { 4.0002433603072, 6.721937850166174 },
312 { 3.190113818788205, 10.51648348411058 },
313 { 4.493600914967883, 7.938224231022314 },
314 { -3.675669533266189, 4.472845076673303 },
315 { 6.648645511703989, 12.03544085965724 },
316 { -1.330031331404445, 1.33931042964811 },
317 { -3.812111460708707, 2.50534195568356 },
318 { 5.669339356648331, 6.214488981177026 },
319 { 1.006596727153816, 1.51165463112716 },
320 { 5.039466365033024, 7.476532610478689 },
321 { 4.349091929968925, 7.446356406259756 },
322 { -1.220289665119069, 3.403926955951437 },
323 { 5.553003979122395, 6.886518211202239 },
324 { 2.274487732222856, 7.009541508533196 },
325 { 4.147567059965864, 7.34025244349202 },
326 { 4.083882618965819, 6.362852861075623 },
327 { 2.203122344647599, 7.260295257904624 },
328 { -2.147497550770442, 1.262293431529498 },
329 { 2.473700950426512, 6.558900135505638 },
330 { 8.267081298847554, 12.10214104577748 },
331 { 6.91977329776865, 9.91998488301285 },
332 { 0.1680479852730894, 6.28286034168897 },
333 { -1.268578659195158, 2.326711221485755 },
334 { 1.829966451374701, 6.254187605304518 },
335 { 5.648849025754848, 9.330002040750291 },
336 { -2.302874793257666, 3.585545172776065 },
337 { -2.629218791709046, 2.156215538500288 },
338 { 4.036618140700114, 10.2962785719958 },
339 { 0.4616386422783874, 0.6782756325806778 },
340 { -0.3447896073408363, 0.4999834691645118 },
341 { -0.475281453118318, 1.931470384180492 },
342 { 2.382509690609731, 6.071782429815853 },
343 { -3.203934441889096, 2.572079552602468 },
344 { 8.465636032165087, 13.96462998683518 },
345 { 2.36755660870416, 5.7844595007273 },
346 { 0.5935496528993371, 1.374615871358943 },
347 { -2.467481505748694, 2.097224634713005 },
348 { 4.27867444328542, 10.24772361238549 },
349 { -2.013791907543137, 2.013799426047639 },
350 { 6.424588084404173, 9.185334939684516 },
351 { -0.8448238876802175, 0.5447382022282812 },
352 { 1.342955703473923, 8.645456317633556 },
353 { 3.108712208751979, 8.512156853800064 },
354 { 4.343205178315472, 8.056869549234374 },
355 { -2.971767642212396, 3.201180146824761 },
356 { 2.583820931523672, 5.459873414473854 },
357 { 4.209139115268925, 8.171098193546225 },
358 { 0.4064909057902746, 1.454390775518743 },
359 { 3.068642411145223, 6.959485153620035 },
360 { 6.085968972900461, 7.391429799500965 },
361 { -1.342265795764202, 1.454550012997143 },
362 { 6.249773274516883, 6.290269880772023 },
363 { 4.986225847822566, 7.75266344868907 },
364 { 7.642443254378944, 10.19914817500263 },
365 { 6.438181159163673, 8.464396764810347 },
366 { 2.520859761025108, 7.68222425260111 },
367 { 2.883699944257541, 6.777960331348503 },
368 { 2.788004550956599, 6.634735386652733 },
369 { 3.331661231995638, 5.794191300046592 },
370 { 3.526172276645504, 6.710802266815884 },
371 { 3.188298528138741, 10.34495528210205 },
372 { 0.7345539486114623, 5.807604004180681 },
373 { 1.165044595880125, 7.830121829295257 },
374 { 7.146962523500671, 11.62995162065415 },
375 { 7.813872137162087, 10.62827008714735 },
376 { 3.118099164870063, 8.286003148186371 },
377 { -1.708739286262571, 1.561026755374264 },
378 { 1.786163047580084, 4.172394388214604 },
379 { 3.718506403232386, 7.807752990130349 },
380 { 6.167414046828899, 10.01104941031293 },
381 { -1.063477247689196, 1.61176085846339 },
382 { -3.396739609433642, 0.7127911050002151 },
383 { 2.438885945896797, 7.353011138689225 },
384 { -0.2073204144780931, 0.850771146627012 }, };
385 }
386 }