1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * https://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 /*
19 * This is not the original file distributed by the Apache Software Foundation
20 * It has been modified by the Hipparchus project
21 */
22
23 package org.hipparchus.clustering;
24
25 import java.util.ArrayList;
26 import java.util.Collection;
27 import java.util.Collections;
28 import java.util.List;
29
30 import org.hipparchus.clustering.distance.DistanceMeasure;
31 import org.hipparchus.clustering.distance.EuclideanDistance;
32 import org.hipparchus.exception.LocalizedCoreFormats;
33 import org.hipparchus.exception.MathIllegalArgumentException;
34 import org.hipparchus.exception.MathIllegalStateException;
35 import org.hipparchus.random.JDKRandomGenerator;
36 import org.hipparchus.random.RandomGenerator;
37 import org.hipparchus.stat.descriptive.moment.Variance;
38 import org.hipparchus.util.MathUtils;
39
40 /**
41 * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
42 * @param <T> type of the points to cluster
43 * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
44 */
45 public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T> {
46
47 /** Strategies to use for replacing an empty cluster. */
48 public enum EmptyClusterStrategy {
49
50 /** Split the cluster with largest distance variance. */
51 LARGEST_VARIANCE,
52
53 /** Split the cluster with largest number of points. */
54 LARGEST_POINTS_NUMBER,
55
56 /** Create a cluster around the point farthest from its centroid. */
57 FARTHEST_POINT,
58
59 /** Generate an error. */
60 ERROR
61
62 }
63
64 /** The number of clusters. */
65 private final int k;
66
67 /** The maximum number of iterations. */
68 private final int maxIterations;
69
70 /** Random generator for choosing initial centers. */
71 private final RandomGenerator random;
72
73 /** Selected strategy for empty clusters. */
74 private final EmptyClusterStrategy emptyStrategy;
75
76 /** Build a clusterer.
77 * <p>
78 * The default strategy for handling empty clusters that may appear during
79 * algorithm iterations is to split the cluster with largest distance variance.
80 * <p>
81 * The euclidean distance will be used as default distance measure.
82 *
83 * @param k the number of clusters to split the data into
84 */
85 public KMeansPlusPlusClusterer(final int k) {
86 this(k, -1);
87 }
88
89 /** Build a clusterer.
90 * <p>
91 * The default strategy for handling empty clusters that may appear during
92 * algorithm iterations is to split the cluster with largest distance variance.
93 * <p>
94 * The euclidean distance will be used as default distance measure.
95 *
96 * @param k the number of clusters to split the data into
97 * @param maxIterations the maximum number of iterations to run the algorithm for.
98 * If negative, no maximum will be used.
99 */
100 public KMeansPlusPlusClusterer(final int k, final int maxIterations) {
101 this(k, maxIterations, new EuclideanDistance());
102 }
103
104 /** Build a clusterer.
105 * <p>
106 * The default strategy for handling empty clusters that may appear during
107 * algorithm iterations is to split the cluster with largest distance variance.
108 *
109 * @param k the number of clusters to split the data into
110 * @param maxIterations the maximum number of iterations to run the algorithm for.
111 * If negative, no maximum will be used.
112 * @param measure the distance measure to use
113 */
114 public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure) {
115 this(k, maxIterations, measure, new JDKRandomGenerator());
116 }
117
118 /** Build a clusterer.
119 * <p>
120 * The default strategy for handling empty clusters that may appear during
121 * algorithm iterations is to split the cluster with largest distance variance.
122 *
123 * @param k the number of clusters to split the data into
124 * @param maxIterations the maximum number of iterations to run the algorithm for.
125 * If negative, no maximum will be used.
126 * @param measure the distance measure to use
127 * @param random random generator to use for choosing initial centers
128 */
129 public KMeansPlusPlusClusterer(final int k, final int maxIterations,
130 final DistanceMeasure measure,
131 final RandomGenerator random) {
132 this(k, maxIterations, measure, random, EmptyClusterStrategy.LARGEST_VARIANCE);
133 }
134
135 /** Build a clusterer.
136 *
137 * @param k the number of clusters to split the data into
138 * @param maxIterations the maximum number of iterations to run the algorithm for.
139 * If negative, no maximum will be used.
140 * @param measure the distance measure to use
141 * @param random random generator to use for choosing initial centers
142 * @param emptyStrategy strategy to use for handling empty clusters that
143 * may appear during algorithm iterations
144 */
145 public KMeansPlusPlusClusterer(final int k, final int maxIterations,
146 final DistanceMeasure measure,
147 final RandomGenerator random,
148 final EmptyClusterStrategy emptyStrategy) {
149 super(measure);
150 this.k = k;
151 this.maxIterations = maxIterations;
152 this.random = random;
153 this.emptyStrategy = emptyStrategy;
154 }
155
156 /**
157 * Return the number of clusters this instance will use.
158 * @return the number of clusters
159 */
160 public int getK() {
161 return k;
162 }
163
164 /**
165 * Returns the maximum number of iterations this instance will use.
166 * @return the maximum number of iterations, or -1 if no maximum is set
167 */
168 public int getMaxIterations() {
169 return maxIterations;
170 }
171
172 /**
173 * Returns the random generator this instance will use.
174 * @return the random generator
175 */
176 public RandomGenerator getRandomGenerator() {
177 return random;
178 }
179
180 /**
181 * Returns the {@link EmptyClusterStrategy} used by this instance.
182 * @return the {@link EmptyClusterStrategy}
183 */
184 public EmptyClusterStrategy getEmptyClusterStrategy() {
185 return emptyStrategy;
186 }
187
188 /**
189 * Runs the K-means++ clustering algorithm.
190 *
191 * @param points the points to cluster
192 * @return a list of clusters containing the points
193 * @throws MathIllegalArgumentException if the data points are null or the number
194 * of clusters is larger than the number of data points
195 * @throws MathIllegalStateException if an empty cluster is encountered and the
196 * {@link #emptyStrategy} is set to {@code ERROR}
197 */
198 @Override
199 public List<CentroidCluster<T>> cluster(final Collection<T> points)
200 throws MathIllegalArgumentException, MathIllegalStateException {
201
202 // sanity checks
203 MathUtils.checkNotNull(points);
204
205 // number of clusters has to be smaller or equal the number of data points
206 if (points.size() < k) {
207 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL_BOUND_EXCLUDED,
208 points.size(), k);
209 }
210
211 // create the initial clusters
212 List<CentroidCluster<T>> clusters = chooseInitialCenters(points);
213
214 // create an array containing the latest assignment of a point to a cluster
215 // no need to initialize the array, as it will be filled with the first assignment
216 int[] assignments = new int[points.size()];
217 assignPointsToClusters(clusters, points, assignments);
218
219 // iterate through updating the centers until we're done
220 final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
221 for (int count = 0; count < max; count++) {
222 boolean emptyCluster = false;
223 List<CentroidCluster<T>> newClusters = new ArrayList<>();
224 for (final CentroidCluster<T> cluster : clusters) {
225 final Clusterable newCenter;
226 if (cluster.getPoints().isEmpty()) {
227 switch (emptyStrategy) {
228 case LARGEST_VARIANCE :
229 newCenter = getPointFromLargestVarianceCluster(clusters);
230 break;
231 case LARGEST_POINTS_NUMBER :
232 newCenter = getPointFromLargestNumberCluster(clusters);
233 break;
234 case FARTHEST_POINT :
235 newCenter = getFarthestPoint(clusters);
236 break;
237 default :
238 throw new MathIllegalStateException(LocalizedClusteringFormats.EMPTY_CLUSTER_IN_K_MEANS);
239 }
240 emptyCluster = true;
241 } else {
242 newCenter = centroidOf(cluster.getPoints(), cluster.getCenter().getPoint().length);
243 }
244 newClusters.add(new CentroidCluster<>(newCenter));
245 }
246 int changes = assignPointsToClusters(newClusters, points, assignments);
247 clusters = newClusters;
248
249 // if there were no more changes in the point-to-cluster assignment
250 // and there are no empty clusters left, return the current clusters
251 if (changes == 0 && !emptyCluster) {
252 return clusters;
253 }
254 }
255 return clusters;
256 }
257
258 /**
259 * Adds the given points to the closest {@link Cluster}.
260 *
261 * @param clusters the {@link Cluster}s to add the points to
262 * @param points the points to add to the given {@link Cluster}s
263 * @param assignments points assignments to clusters
264 * @return the number of points assigned to different clusters as the iteration before
265 */
266 private int assignPointsToClusters(final List<CentroidCluster<T>> clusters,
267 final Collection<T> points,
268 final int[] assignments) {
269 int assignedDifferently = 0;
270 int pointIndex = 0;
271 for (final T p : points) {
272 int clusterIndex = getNearestCluster(clusters, p);
273 if (clusterIndex != assignments[pointIndex]) {
274 assignedDifferently++;
275 }
276
277 CentroidCluster<T> cluster = clusters.get(clusterIndex);
278 cluster.addPoint(p);
279 assignments[pointIndex++] = clusterIndex;
280 }
281
282 return assignedDifferently;
283 }
284
285 /**
286 * Use K-means++ to choose the initial centers.
287 *
288 * @param points the points to choose the initial centers from
289 * @return the initial centers
290 */
291 private List<CentroidCluster<T>> chooseInitialCenters(final Collection<T> points) {
292
293 // Convert to list for indexed access. Make it unmodifiable, since removal of items
294 // would screw up the logic of this method.
295 final List<T> pointList = Collections.unmodifiableList(new ArrayList<>(points));
296
297 // The number of points in the list.
298 final int numPoints = pointList.size();
299
300 // Set the corresponding element in this array to indicate when
301 // elements of pointList are no longer available.
302 final boolean[] taken = new boolean[numPoints];
303
304 // The resulting list of initial centers.
305 final List<CentroidCluster<T>> resultSet = new ArrayList<>();
306
307 // Choose one center uniformly at random from among the data points.
308 final int firstPointIndex = random.nextInt(numPoints);
309
310 final T firstPoint = pointList.get(firstPointIndex);
311
312 resultSet.add(new CentroidCluster<>(firstPoint));
313
314 // Must mark it as taken
315 taken[firstPointIndex] = true;
316
317 // To keep track of the minimum distance squared of elements of
318 // pointList to elements of resultSet.
319 final double[] minDistSquared = new double[numPoints];
320
321 // Initialize the elements. Since the only point in resultSet is firstPoint,
322 // this is very easy.
323 for (int i = 0; i < numPoints; i++) {
324 if (i != firstPointIndex) { // That point isn't considered
325 double d = distance(firstPoint, pointList.get(i));
326 minDistSquared[i] = d*d;
327 }
328 }
329
330 while (resultSet.size() < k) {
331
332 // Sum up the squared distances for the points in pointList not
333 // already taken.
334 double distSqSum = 0.0;
335
336 for (int i = 0; i < numPoints; i++) {
337 if (!taken[i]) {
338 distSqSum += minDistSquared[i];
339 }
340 }
341
342 // Add one new data point as a center. Each point x is chosen with
343 // probability proportional to D(x)2
344 final double r = random.nextDouble() * distSqSum;
345
346 // The index of the next point to be added to the resultSet.
347 int nextPointIndex = -1;
348
349 // Sum through the squared min distances again, stopping when
350 // sum >= r.
351 double sum = 0.0;
352 for (int i = 0; i < numPoints; i++) {
353 if (!taken[i]) {
354 sum += minDistSquared[i];
355 if (sum >= r) {
356 nextPointIndex = i;
357 break;
358 }
359 }
360 }
361
362 // If it's not set to >= 0, the point wasn't found in the previous
363 // for loop, probably because distances are extremely small. Just pick
364 // the last available point.
365 if (nextPointIndex == -1) {
366 for (int i = numPoints - 1; i >= 0; i--) {
367 if (!taken[i]) {
368 nextPointIndex = i;
369 break;
370 }
371 }
372 }
373
374 // We found one.
375 if (nextPointIndex >= 0) {
376
377 final T p = pointList.get(nextPointIndex);
378
379 resultSet.add(new CentroidCluster<>(p));
380
381 // Mark it as taken.
382 taken[nextPointIndex] = true;
383
384 if (resultSet.size() < k) {
385 // Now update elements of minDistSquared. We only have to compute
386 // the distance to the new center to do this.
387 for (int j = 0; j < numPoints; j++) {
388 // Only have to worry about the points still not taken.
389 if (!taken[j]) {
390 double d = distance(p, pointList.get(j));
391 double d2 = d * d;
392 if (d2 < minDistSquared[j]) {
393 minDistSquared[j] = d2;
394 }
395 }
396 }
397 }
398
399 } else {
400 // None found --
401 // Break from the while loop to prevent
402 // an infinite loop.
403 break;
404 }
405 }
406
407 return resultSet;
408 }
409
410 /**
411 * Get a random point from the {@link Cluster} with the largest distance variance.
412 *
413 * @param clusters the {@link Cluster}s to search
414 * @return a random point from the selected cluster
415 * @throws MathIllegalStateException if clusters are all empty
416 */
417 private T getPointFromLargestVarianceCluster(final Collection<CentroidCluster<T>> clusters)
418 throws MathIllegalStateException {
419
420 double maxVariance = Double.NEGATIVE_INFINITY;
421 Cluster<T> selected = null;
422 for (final CentroidCluster<T> cluster : clusters) {
423 if (!cluster.getPoints().isEmpty()) {
424
425 // compute the distance variance of the current cluster
426 final Clusterable center = cluster.getCenter();
427 final Variance stat = new Variance();
428 for (final T point : cluster.getPoints()) {
429 stat.increment(distance(point, center));
430 }
431 final double variance = stat.getResult();
432
433 // select the cluster with the largest variance
434 if (variance > maxVariance) {
435 maxVariance = variance;
436 selected = cluster;
437 }
438
439 }
440 }
441
442 // did we find at least one non-empty cluster ?
443 if (selected == null) {
444 throw new MathIllegalStateException(LocalizedClusteringFormats.EMPTY_CLUSTER_IN_K_MEANS);
445 }
446
447 // extract a random point from the cluster
448 final List<T> selectedPoints = selected.getPoints();
449 return selectedPoints.remove(random.nextInt(selectedPoints.size()));
450
451 }
452
453 /**
454 * Get a random point from the {@link Cluster} with the largest number of points
455 *
456 * @param clusters the {@link Cluster}s to search
457 * @return a random point from the selected cluster
458 * @throws MathIllegalStateException if clusters are all empty
459 */
460 private T getPointFromLargestNumberCluster(final Collection<? extends Cluster<T>> clusters)
461 throws MathIllegalStateException {
462
463 int maxNumber = 0;
464 Cluster<T> selected = null;
465 for (final Cluster<T> cluster : clusters) {
466
467 // get the number of points of the current cluster
468 final int number = cluster.getPoints().size();
469
470 // select the cluster with the largest number of points
471 if (number > maxNumber) {
472 maxNumber = number;
473 selected = cluster;
474 }
475
476 }
477
478 // did we find at least one non-empty cluster ?
479 if (selected == null) {
480 throw new MathIllegalStateException(LocalizedClusteringFormats.EMPTY_CLUSTER_IN_K_MEANS);
481 }
482
483 // extract a random point from the cluster
484 final List<T> selectedPoints = selected.getPoints();
485 return selectedPoints.remove(random.nextInt(selectedPoints.size()));
486
487 }
488
489 /**
490 * Get the point farthest to its cluster center
491 *
492 * @param clusters the {@link Cluster}s to search
493 * @return point farthest to its cluster center
494 * @throws MathIllegalStateException if clusters are all empty
495 */
496 private T getFarthestPoint(final Collection<CentroidCluster<T>> clusters) throws MathIllegalStateException {
497
498 double maxDistance = Double.NEGATIVE_INFINITY;
499 Cluster<T> selectedCluster = null;
500 int selectedPoint = -1;
501 for (final CentroidCluster<T> cluster : clusters) {
502
503 // get the farthest point
504 final Clusterable center = cluster.getCenter();
505 final List<T> points = cluster.getPoints();
506 for (int i = 0; i < points.size(); ++i) {
507 final double distance = distance(points.get(i), center);
508 if (distance > maxDistance) {
509 maxDistance = distance;
510 selectedCluster = cluster;
511 selectedPoint = i;
512 }
513 }
514
515 }
516
517 // did we find at least one non-empty cluster ?
518 if (selectedCluster == null) {
519 throw new MathIllegalStateException(LocalizedClusteringFormats.EMPTY_CLUSTER_IN_K_MEANS);
520 }
521
522 return selectedCluster.getPoints().remove(selectedPoint);
523
524 }
525
526 /**
527 * Returns the nearest {@link Cluster} to the given point
528 *
529 * @param clusters the {@link Cluster}s to search
530 * @param point the point to find the nearest {@link Cluster} for
531 * @return the index of the nearest {@link Cluster} to the given point
532 */
533 private int getNearestCluster(final Collection<CentroidCluster<T>> clusters, final T point) {
534 double minDistance = Double.MAX_VALUE;
535 int clusterIndex = 0;
536 int minCluster = 0;
537 for (final CentroidCluster<T> c : clusters) {
538 final double distance = distance(point, c.getCenter());
539 if (distance < minDistance) {
540 minDistance = distance;
541 minCluster = clusterIndex;
542 }
543 clusterIndex++;
544 }
545 return minCluster;
546 }
547
548 /**
549 * Computes the centroid for a set of points.
550 *
551 * @param points the set of points
552 * @param dimension the point dimension
553 * @return the computed centroid for the set of points
554 */
555 private Clusterable centroidOf(final Collection<T> points, final int dimension) {
556 final double[] centroid = new double[dimension];
557 for (final T p : points) {
558 final double[] point = p.getPoint();
559 for (int i = 0; i < centroid.length; i++) {
560 centroid[i] += point[i];
561 }
562 }
563 for (int i = 0; i < centroid.length; i++) {
564 centroid[i] /= points.size();
565 }
566 return new DoublePoint(centroid);
567 }
568
569 }