1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23 package org.hipparchus.clustering;
24
25 import java.util.ArrayList;
26 import java.util.Collection;
27 import java.util.Collections;
28 import java.util.List;
29
30 import org.hipparchus.clustering.distance.DistanceMeasure;
31 import org.hipparchus.clustering.distance.EuclideanDistance;
32 import org.hipparchus.exception.LocalizedCoreFormats;
33 import org.hipparchus.exception.MathIllegalArgumentException;
34 import org.hipparchus.exception.MathIllegalStateException;
35 import org.hipparchus.random.JDKRandomGenerator;
36 import org.hipparchus.random.RandomGenerator;
37 import org.hipparchus.stat.descriptive.moment.Variance;
38 import org.hipparchus.util.MathUtils;
39
40
41
42
43
44
45 public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T> {
46
47
48 public enum EmptyClusterStrategy {
49
50
51 LARGEST_VARIANCE,
52
53
54 LARGEST_POINTS_NUMBER,
55
56
57 FARTHEST_POINT,
58
59
60 ERROR
61
62 }
63
64
65 private final int k;
66
67
68 private final int maxIterations;
69
70
71 private final RandomGenerator random;
72
73
74 private final EmptyClusterStrategy emptyStrategy;
75
76
77
78
79
80
81
82
83
84
85 public KMeansPlusPlusClusterer(final int k) {
86 this(k, -1);
87 }
88
89
90
91
92
93
94
95
96
97
98
99
100 public KMeansPlusPlusClusterer(final int k, final int maxIterations) {
101 this(k, maxIterations, new EuclideanDistance());
102 }
103
104
105
106
107
108
109
110
111
112
113
114 public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure) {
115 this(k, maxIterations, measure, new JDKRandomGenerator());
116 }
117
118
119
120
121
122
123
124
125
126
127
128
129 public KMeansPlusPlusClusterer(final int k, final int maxIterations,
130 final DistanceMeasure measure,
131 final RandomGenerator random) {
132 this(k, maxIterations, measure, random, EmptyClusterStrategy.LARGEST_VARIANCE);
133 }
134
135
136
137
138
139
140
141
142
143
144
145 public KMeansPlusPlusClusterer(final int k, final int maxIterations,
146 final DistanceMeasure measure,
147 final RandomGenerator random,
148 final EmptyClusterStrategy emptyStrategy) {
149 super(measure);
150 this.k = k;
151 this.maxIterations = maxIterations;
152 this.random = random;
153 this.emptyStrategy = emptyStrategy;
154 }
155
156
157
158
159
160 public int getK() {
161 return k;
162 }
163
164
165
166
167
168 public int getMaxIterations() {
169 return maxIterations;
170 }
171
172
173
174
175
176 public RandomGenerator getRandomGenerator() {
177 return random;
178 }
179
180
181
182
183
184 public EmptyClusterStrategy getEmptyClusterStrategy() {
185 return emptyStrategy;
186 }
187
188
189
190
191
192
193
194
195
196
197
198 @Override
199 public List<CentroidCluster<T>> cluster(final Collection<T> points)
200 throws MathIllegalArgumentException, MathIllegalStateException {
201
202
203 MathUtils.checkNotNull(points);
204
205
206 if (points.size() < k) {
207 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL_BOUND_EXCLUDED,
208 points.size(), k);
209 }
210
211
212 List<CentroidCluster<T>> clusters = chooseInitialCenters(points);
213
214
215
216 int[] assignments = new int[points.size()];
217 assignPointsToClusters(clusters, points, assignments);
218
219
220 final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
221 for (int count = 0; count < max; count++) {
222 boolean emptyCluster = false;
223 List<CentroidCluster<T>> newClusters = new ArrayList<>();
224 for (final CentroidCluster<T> cluster : clusters) {
225 final Clusterable newCenter;
226 if (cluster.getPoints().isEmpty()) {
227 switch (emptyStrategy) {
228 case LARGEST_VARIANCE :
229 newCenter = getPointFromLargestVarianceCluster(clusters);
230 break;
231 case LARGEST_POINTS_NUMBER :
232 newCenter = getPointFromLargestNumberCluster(clusters);
233 break;
234 case FARTHEST_POINT :
235 newCenter = getFarthestPoint(clusters);
236 break;
237 default :
238 throw new MathIllegalStateException(LocalizedClusteringFormats.EMPTY_CLUSTER_IN_K_MEANS);
239 }
240 emptyCluster = true;
241 } else {
242 newCenter = centroidOf(cluster.getPoints(), cluster.getCenter().getPoint().length);
243 }
244 newClusters.add(new CentroidCluster<T>(newCenter));
245 }
246 int changes = assignPointsToClusters(newClusters, points, assignments);
247 clusters = newClusters;
248
249
250
251 if (changes == 0 && !emptyCluster) {
252 return clusters;
253 }
254 }
255 return clusters;
256 }
257
258
259
260
261
262
263
264
265
266 private int assignPointsToClusters(final List<CentroidCluster<T>> clusters,
267 final Collection<T> points,
268 final int[] assignments) {
269 int assignedDifferently = 0;
270 int pointIndex = 0;
271 for (final T p : points) {
272 int clusterIndex = getNearestCluster(clusters, p);
273 if (clusterIndex != assignments[pointIndex]) {
274 assignedDifferently++;
275 }
276
277 CentroidCluster<T> cluster = clusters.get(clusterIndex);
278 cluster.addPoint(p);
279 assignments[pointIndex++] = clusterIndex;
280 }
281
282 return assignedDifferently;
283 }
284
285
286
287
288
289
290
291 private List<CentroidCluster<T>> chooseInitialCenters(final Collection<T> points) {
292
293
294
295 final List<T> pointList = Collections.unmodifiableList(new ArrayList<T> (points));
296
297
298 final int numPoints = pointList.size();
299
300
301
302 final boolean[] taken = new boolean[numPoints];
303
304
305 final List<CentroidCluster<T>> resultSet = new ArrayList<>();
306
307
308 final int firstPointIndex = random.nextInt(numPoints);
309
310 final T firstPoint = pointList.get(firstPointIndex);
311
312 resultSet.add(new CentroidCluster<T>(firstPoint));
313
314
315 taken[firstPointIndex] = true;
316
317
318
319 final double[] minDistSquared = new double[numPoints];
320
321
322
323 for (int i = 0; i < numPoints; i++) {
324 if (i != firstPointIndex) {
325 double d = distance(firstPoint, pointList.get(i));
326 minDistSquared[i] = d*d;
327 }
328 }
329
330 while (resultSet.size() < k) {
331
332
333
334 double distSqSum = 0.0;
335
336 for (int i = 0; i < numPoints; i++) {
337 if (!taken[i]) {
338 distSqSum += minDistSquared[i];
339 }
340 }
341
342
343
344 final double r = random.nextDouble() * distSqSum;
345
346
347 int nextPointIndex = -1;
348
349
350
351 double sum = 0.0;
352 for (int i = 0; i < numPoints; i++) {
353 if (!taken[i]) {
354 sum += minDistSquared[i];
355 if (sum >= r) {
356 nextPointIndex = i;
357 break;
358 }
359 }
360 }
361
362
363
364
365 if (nextPointIndex == -1) {
366 for (int i = numPoints - 1; i >= 0; i--) {
367 if (!taken[i]) {
368 nextPointIndex = i;
369 break;
370 }
371 }
372 }
373
374
375 if (nextPointIndex >= 0) {
376
377 final T p = pointList.get(nextPointIndex);
378
379 resultSet.add(new CentroidCluster<T> (p));
380
381
382 taken[nextPointIndex] = true;
383
384 if (resultSet.size() < k) {
385
386
387 for (int j = 0; j < numPoints; j++) {
388
389 if (!taken[j]) {
390 double d = distance(p, pointList.get(j));
391 double d2 = d * d;
392 if (d2 < minDistSquared[j]) {
393 minDistSquared[j] = d2;
394 }
395 }
396 }
397 }
398
399 } else {
400
401
402
403 break;
404 }
405 }
406
407 return resultSet;
408 }
409
410
411
412
413
414
415
416
417 private T getPointFromLargestVarianceCluster(final Collection<CentroidCluster<T>> clusters)
418 throws MathIllegalStateException {
419
420 double maxVariance = Double.NEGATIVE_INFINITY;
421 Cluster<T> selected = null;
422 for (final CentroidCluster<T> cluster : clusters) {
423 if (!cluster.getPoints().isEmpty()) {
424
425
426 final Clusterable center = cluster.getCenter();
427 final Variance stat = new Variance();
428 for (final T point : cluster.getPoints()) {
429 stat.increment(distance(point, center));
430 }
431 final double variance = stat.getResult();
432
433
434 if (variance > maxVariance) {
435 maxVariance = variance;
436 selected = cluster;
437 }
438
439 }
440 }
441
442
443 if (selected == null) {
444 throw new MathIllegalStateException(LocalizedClusteringFormats.EMPTY_CLUSTER_IN_K_MEANS);
445 }
446
447
448 final List<T> selectedPoints = selected.getPoints();
449 return selectedPoints.remove(random.nextInt(selectedPoints.size()));
450
451 }
452
453
454
455
456
457
458
459
460 private T getPointFromLargestNumberCluster(final Collection<? extends Cluster<T>> clusters)
461 throws MathIllegalStateException {
462
463 int maxNumber = 0;
464 Cluster<T> selected = null;
465 for (final Cluster<T> cluster : clusters) {
466
467
468 final int number = cluster.getPoints().size();
469
470
471 if (number > maxNumber) {
472 maxNumber = number;
473 selected = cluster;
474 }
475
476 }
477
478
479 if (selected == null) {
480 throw new MathIllegalStateException(LocalizedClusteringFormats.EMPTY_CLUSTER_IN_K_MEANS);
481 }
482
483
484 final List<T> selectedPoints = selected.getPoints();
485 return selectedPoints.remove(random.nextInt(selectedPoints.size()));
486
487 }
488
489
490
491
492
493
494
495
496 private T getFarthestPoint(final Collection<CentroidCluster<T>> clusters) throws MathIllegalStateException {
497
498 double maxDistance = Double.NEGATIVE_INFINITY;
499 Cluster<T> selectedCluster = null;
500 int selectedPoint = -1;
501 for (final CentroidCluster<T> cluster : clusters) {
502
503
504 final Clusterable center = cluster.getCenter();
505 final List<T> points = cluster.getPoints();
506 for (int i = 0; i < points.size(); ++i) {
507 final double distance = distance(points.get(i), center);
508 if (distance > maxDistance) {
509 maxDistance = distance;
510 selectedCluster = cluster;
511 selectedPoint = i;
512 }
513 }
514
515 }
516
517
518 if (selectedCluster == null) {
519 throw new MathIllegalStateException(LocalizedClusteringFormats.EMPTY_CLUSTER_IN_K_MEANS);
520 }
521
522 return selectedCluster.getPoints().remove(selectedPoint);
523
524 }
525
526
527
528
529
530
531
532
533 private int getNearestCluster(final Collection<CentroidCluster<T>> clusters, final T point) {
534 double minDistance = Double.MAX_VALUE;
535 int clusterIndex = 0;
536 int minCluster = 0;
537 for (final CentroidCluster<T> c : clusters) {
538 final double distance = distance(point, c.getCenter());
539 if (distance < minDistance) {
540 minDistance = distance;
541 minCluster = clusterIndex;
542 }
543 clusterIndex++;
544 }
545 return minCluster;
546 }
547
548
549
550
551
552
553
554
555 private Clusterable centroidOf(final Collection<T> points, final int dimension) {
556 final double[] centroid = new double[dimension];
557 for (final T p : points) {
558 final double[] point = p.getPoint();
559 for (int i = 0; i < centroid.length; i++) {
560 centroid[i] += point[i];
561 }
562 }
563 for (int i = 0; i < centroid.length; i++) {
564 centroid[i] /= points.size();
565 }
566 return new DoublePoint(centroid);
567 }
568
569 }