1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23 package org.hipparchus.clustering;
24
25 import java.util.ArrayList;
26 import java.util.Arrays;
27 import java.util.Collection;
28 import java.util.List;
29
30 import org.hipparchus.clustering.distance.EuclideanDistance;
31 import org.hipparchus.exception.MathIllegalArgumentException;
32 import org.hipparchus.random.JDKRandomGenerator;
33 import org.hipparchus.random.RandomGenerator;
34 import org.junit.Assert;
35 import org.junit.Before;
36 import org.junit.Test;
37
38 public class KMeansPlusPlusClustererTest {
39
40 private RandomGenerator random;
41
42 @Before
43 public void setUp() {
44 random = new JDKRandomGenerator();
45 random.setSeed(1746432956321l);
46 }
47
48
49
50
51
52
53 @Test
54 public void testPerformClusterAnalysisDegenerate() {
55 KMeansPlusPlusClusterer<DoublePoint> transformer =
56 new KMeansPlusPlusClusterer<DoublePoint>(1, 1);
57
58 DoublePoint[] points = new DoublePoint[] {
59 new DoublePoint(new int[] { 1959, 325100 }),
60 new DoublePoint(new int[] { 1960, 373200 }), };
61 List<? extends Cluster<DoublePoint>> clusters = transformer.cluster(Arrays.asList(points));
62 Assert.assertEquals(1, clusters.size());
63 Assert.assertEquals(2, (clusters.get(0).getPoints().size()));
64 DoublePoint pt1 = new DoublePoint(new int[] { 1959, 325100 });
65 DoublePoint pt2 = new DoublePoint(new int[] { 1960, 373200 });
66 Assert.assertTrue(clusters.get(0).getPoints().contains(pt1));
67 Assert.assertTrue(clusters.get(0).getPoints().contains(pt2));
68
69 }
70
71 @Test
72 public void testCertainSpace() {
73 KMeansPlusPlusClusterer.EmptyClusterStrategy[] strategies = {
74 KMeansPlusPlusClusterer.EmptyClusterStrategy.LARGEST_VARIANCE,
75 KMeansPlusPlusClusterer.EmptyClusterStrategy.LARGEST_POINTS_NUMBER,
76 KMeansPlusPlusClusterer.EmptyClusterStrategy.FARTHEST_POINT
77 };
78 for (KMeansPlusPlusClusterer.EmptyClusterStrategy strategy : strategies) {
79 int numberOfVariables = 27;
80
81 int position1 = 1;
82 int position2 = position1 + numberOfVariables;
83 int position3 = position2 + numberOfVariables;
84 int position4 = position3 + numberOfVariables;
85
86 int multiplier = 1000000;
87
88 DoublePoint[] breakingPoints = new DoublePoint[numberOfVariables];
89
90 for (int i = 0; i < numberOfVariables; i++) {
91 int[] points = { position1, position2, position3, position4 };
92
93 for (int j = 0; j < points.length; j++) {
94 points[j] *= multiplier;
95 }
96 DoublePoint DoublePoint = new DoublePoint(points);
97 breakingPoints[i] = DoublePoint;
98 position1 += numberOfVariables;
99 position2 += numberOfVariables;
100 position3 += numberOfVariables;
101 position4 += numberOfVariables;
102 }
103
104 for (int n = 2; n < 27; ++n) {
105 KMeansPlusPlusClusterer<DoublePoint> transformer =
106 new KMeansPlusPlusClusterer<DoublePoint>(n, 100, new EuclideanDistance(), random, strategy);
107
108 List<? extends Cluster<DoublePoint>> clusters =
109 transformer.cluster(Arrays.asList(breakingPoints));
110
111 Assert.assertEquals(n, clusters.size());
112 int sum = 0;
113 for (Cluster<DoublePoint> cluster : clusters) {
114 sum += cluster.getPoints().size();
115 }
116 Assert.assertEquals(numberOfVariables, sum);
117 }
118 }
119
120 }
121
122
123
124
125
126 private class CloseDistance extends EuclideanDistance {
127 private static final long serialVersionUID = 1L;
128
129 @Override
130 public double compute(double[] a, double[] b) {
131 return super.compute(a, b) * 0.001;
132 }
133 }
134
135
136
137
138 @Test
139 public void testSmallDistances() {
140
141
142 int[] repeatedArray = { 0 };
143 int[] uniqueArray = { 1 };
144 DoublePoint repeatedPoint = new DoublePoint(repeatedArray);
145 DoublePoint uniquePoint = new DoublePoint(uniqueArray);
146
147 Collection<DoublePoint> points = new ArrayList<DoublePoint>();
148 final int NUM_REPEATED_POINTS = 10 * 1000;
149 for (int i = 0; i < NUM_REPEATED_POINTS; ++i) {
150 points.add(repeatedPoint);
151 }
152 points.add(uniquePoint);
153
154
155
156 final long RANDOM_SEED = 0;
157 final int NUM_CLUSTERS = 2;
158 final int NUM_ITERATIONS = 0;
159 random.setSeed(RANDOM_SEED);
160
161 KMeansPlusPlusClusterer<DoublePoint> clusterer =
162 new KMeansPlusPlusClusterer<DoublePoint>(NUM_CLUSTERS, NUM_ITERATIONS,
163 new CloseDistance(), random);
164 List<CentroidCluster<DoublePoint>> clusters = clusterer.cluster(points);
165
166
167 boolean uniquePointIsCenter = false;
168 for (CentroidCluster<DoublePoint> cluster : clusters) {
169 if (cluster.getCenter().equals(uniquePoint)) {
170 uniquePointIsCenter = true;
171 }
172 }
173 Assert.assertTrue(uniquePointIsCenter);
174 }
175
176
177
178
179 @Test(expected=MathIllegalArgumentException.class)
180 public void testPerformClusterAnalysisToManyClusters() {
181 KMeansPlusPlusClusterer<DoublePoint> transformer =
182 new KMeansPlusPlusClusterer<DoublePoint>(3, 1, new EuclideanDistance(), random);
183
184 DoublePoint[] points = new DoublePoint[] {
185 new DoublePoint(new int[] {
186 1959, 325100
187 }), new DoublePoint(new int[] {
188 1960, 373200
189 })
190 };
191
192 transformer.cluster(Arrays.asList(points));
193
194 }
195
196 }