View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  
23  package org.hipparchus.clustering;
24  
25  import java.util.ArrayList;
26  import java.util.Collection;
27  import java.util.Collections;
28  import java.util.List;
29  
30  import org.hipparchus.clustering.distance.DistanceMeasure;
31  import org.hipparchus.clustering.distance.EuclideanDistance;
32  import org.hipparchus.exception.LocalizedCoreFormats;
33  import org.hipparchus.exception.MathIllegalArgumentException;
34  import org.hipparchus.exception.MathIllegalStateException;
35  import org.hipparchus.random.JDKRandomGenerator;
36  import org.hipparchus.random.RandomGenerator;
37  import org.hipparchus.stat.descriptive.moment.Variance;
38  import org.hipparchus.util.MathUtils;
39  
40  /**
41   * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
42   * @param <T> type of the points to cluster
43   * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
44   */
45  public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T> {
46  
47      /** Strategies to use for replacing an empty cluster. */
48      public enum EmptyClusterStrategy {
49  
50          /** Split the cluster with largest distance variance. */
51          LARGEST_VARIANCE,
52  
53          /** Split the cluster with largest number of points. */
54          LARGEST_POINTS_NUMBER,
55  
56          /** Create a cluster around the point farthest from its centroid. */
57          FARTHEST_POINT,
58  
59          /** Generate an error. */
60          ERROR
61  
62      }
63  
64      /** The number of clusters. */
65      private final int k;
66  
67      /** The maximum number of iterations. */
68      private final int maxIterations;
69  
70      /** Random generator for choosing initial centers. */
71      private final RandomGenerator random;
72  
73      /** Selected strategy for empty clusters. */
74      private final EmptyClusterStrategy emptyStrategy;
75  
76      /** Build a clusterer.
77       * <p>
78       * The default strategy for handling empty clusters that may appear during
79       * algorithm iterations is to split the cluster with largest distance variance.
80       * <p>
81       * The euclidean distance will be used as default distance measure.
82       *
83       * @param k the number of clusters to split the data into
84       */
85      public KMeansPlusPlusClusterer(final int k) {
86          this(k, -1);
87      }
88  
89      /** Build a clusterer.
90       * <p>
91       * The default strategy for handling empty clusters that may appear during
92       * algorithm iterations is to split the cluster with largest distance variance.
93       * <p>
94       * The euclidean distance will be used as default distance measure.
95       *
96       * @param k the number of clusters to split the data into
97       * @param maxIterations the maximum number of iterations to run the algorithm for.
98       *   If negative, no maximum will be used.
99       */
100     public KMeansPlusPlusClusterer(final int k, final int maxIterations) {
101         this(k, maxIterations, new EuclideanDistance());
102     }
103 
104     /** Build a clusterer.
105      * <p>
106      * The default strategy for handling empty clusters that may appear during
107      * algorithm iterations is to split the cluster with largest distance variance.
108      *
109      * @param k the number of clusters to split the data into
110      * @param maxIterations the maximum number of iterations to run the algorithm for.
111      *   If negative, no maximum will be used.
112      * @param measure the distance measure to use
113      */
114     public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure) {
115         this(k, maxIterations, measure, new JDKRandomGenerator());
116     }
117 
118     /** Build a clusterer.
119      * <p>
120      * The default strategy for handling empty clusters that may appear during
121      * algorithm iterations is to split the cluster with largest distance variance.
122      *
123      * @param k the number of clusters to split the data into
124      * @param maxIterations the maximum number of iterations to run the algorithm for.
125      *   If negative, no maximum will be used.
126      * @param measure the distance measure to use
127      * @param random random generator to use for choosing initial centers
128      */
129     public KMeansPlusPlusClusterer(final int k, final int maxIterations,
130                                    final DistanceMeasure measure,
131                                    final RandomGenerator random) {
132         this(k, maxIterations, measure, random, EmptyClusterStrategy.LARGEST_VARIANCE);
133     }
134 
135     /** Build a clusterer.
136      *
137      * @param k the number of clusters to split the data into
138      * @param maxIterations the maximum number of iterations to run the algorithm for.
139      *   If negative, no maximum will be used.
140      * @param measure the distance measure to use
141      * @param random random generator to use for choosing initial centers
142      * @param emptyStrategy strategy to use for handling empty clusters that
143      * may appear during algorithm iterations
144      */
145     public KMeansPlusPlusClusterer(final int k, final int maxIterations,
146                                    final DistanceMeasure measure,
147                                    final RandomGenerator random,
148                                    final EmptyClusterStrategy emptyStrategy) {
149         super(measure);
150         this.k             = k;
151         this.maxIterations = maxIterations;
152         this.random        = random;
153         this.emptyStrategy = emptyStrategy;
154     }
155 
156     /**
157      * Return the number of clusters this instance will use.
158      * @return the number of clusters
159      */
160     public int getK() {
161         return k;
162     }
163 
164     /**
165      * Returns the maximum number of iterations this instance will use.
166      * @return the maximum number of iterations, or -1 if no maximum is set
167      */
168     public int getMaxIterations() {
169         return maxIterations;
170     }
171 
172     /**
173      * Returns the random generator this instance will use.
174      * @return the random generator
175      */
176     public RandomGenerator getRandomGenerator() {
177         return random;
178     }
179 
180     /**
181      * Returns the {@link EmptyClusterStrategy} used by this instance.
182      * @return the {@link EmptyClusterStrategy}
183      */
184     public EmptyClusterStrategy getEmptyClusterStrategy() {
185         return emptyStrategy;
186     }
187 
188     /**
189      * Runs the K-means++ clustering algorithm.
190      *
191      * @param points the points to cluster
192      * @return a list of clusters containing the points
193      * @throws MathIllegalArgumentException if the data points are null or the number
194      *     of clusters is larger than the number of data points
195      * @throws MathIllegalStateException if an empty cluster is encountered and the
196      * {@link #emptyStrategy} is set to {@code ERROR}
197      */
198     @Override
199     public List<CentroidCluster<T>> cluster(final Collection<T> points)
200         throws MathIllegalArgumentException, MathIllegalStateException {
201 
202         // sanity checks
203         MathUtils.checkNotNull(points);
204 
205         // number of clusters has to be smaller or equal the number of data points
206         if (points.size() < k) {
207             throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL_BOUND_EXCLUDED,
208                                                    points.size(), k);
209         }
210 
211         // create the initial clusters
212         List<CentroidCluster<T>> clusters = chooseInitialCenters(points);
213 
214         // create an array containing the latest assignment of a point to a cluster
215         // no need to initialize the array, as it will be filled with the first assignment
216         int[] assignments = new int[points.size()];
217         assignPointsToClusters(clusters, points, assignments);
218 
219         // iterate through updating the centers until we're done
220         final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
221         for (int count = 0; count < max; count++) {
222             boolean emptyCluster = false;
223             List<CentroidCluster<T>> newClusters = new ArrayList<>();
224             for (final CentroidCluster<T> cluster : clusters) {
225                 final Clusterable newCenter;
226                 if (cluster.getPoints().isEmpty()) {
227                     switch (emptyStrategy) {
228                         case LARGEST_VARIANCE :
229                             newCenter = getPointFromLargestVarianceCluster(clusters);
230                             break;
231                         case LARGEST_POINTS_NUMBER :
232                             newCenter = getPointFromLargestNumberCluster(clusters);
233                             break;
234                         case FARTHEST_POINT :
235                             newCenter = getFarthestPoint(clusters);
236                             break;
237                         default :
238                             throw new MathIllegalStateException(LocalizedClusteringFormats.EMPTY_CLUSTER_IN_K_MEANS);
239                     }
240                     emptyCluster = true;
241                 } else {
242                     newCenter = centroidOf(cluster.getPoints(), cluster.getCenter().getPoint().length);
243                 }
244                 newClusters.add(new CentroidCluster<T>(newCenter));
245             }
246             int changes = assignPointsToClusters(newClusters, points, assignments);
247             clusters = newClusters;
248 
249             // if there were no more changes in the point-to-cluster assignment
250             // and there are no empty clusters left, return the current clusters
251             if (changes == 0 && !emptyCluster) {
252                 return clusters;
253             }
254         }
255         return clusters;
256     }
257 
258     /**
259      * Adds the given points to the closest {@link Cluster}.
260      *
261      * @param clusters the {@link Cluster}s to add the points to
262      * @param points the points to add to the given {@link Cluster}s
263      * @param assignments points assignments to clusters
264      * @return the number of points assigned to different clusters as the iteration before
265      */
266     private int assignPointsToClusters(final List<CentroidCluster<T>> clusters,
267                                        final Collection<T> points,
268                                        final int[] assignments) {
269         int assignedDifferently = 0;
270         int pointIndex = 0;
271         for (final T p : points) {
272             int clusterIndex = getNearestCluster(clusters, p);
273             if (clusterIndex != assignments[pointIndex]) {
274                 assignedDifferently++;
275             }
276 
277             CentroidCluster<T> cluster = clusters.get(clusterIndex);
278             cluster.addPoint(p);
279             assignments[pointIndex++] = clusterIndex;
280         }
281 
282         return assignedDifferently;
283     }
284 
285     /**
286      * Use K-means++ to choose the initial centers.
287      *
288      * @param points the points to choose the initial centers from
289      * @return the initial centers
290      */
291     private List<CentroidCluster<T>> chooseInitialCenters(final Collection<T> points) {
292 
293         // Convert to list for indexed access. Make it unmodifiable, since removal of items
294         // would screw up the logic of this method.
295         final List<T> pointList = Collections.unmodifiableList(new ArrayList<T> (points));
296 
297         // The number of points in the list.
298         final int numPoints = pointList.size();
299 
300         // Set the corresponding element in this array to indicate when
301         // elements of pointList are no longer available.
302         final boolean[] taken = new boolean[numPoints];
303 
304         // The resulting list of initial centers.
305         final List<CentroidCluster<T>> resultSet = new ArrayList<>();
306 
307         // Choose one center uniformly at random from among the data points.
308         final int firstPointIndex = random.nextInt(numPoints);
309 
310         final T firstPoint = pointList.get(firstPointIndex);
311 
312         resultSet.add(new CentroidCluster<T>(firstPoint));
313 
314         // Must mark it as taken
315         taken[firstPointIndex] = true;
316 
317         // To keep track of the minimum distance squared of elements of
318         // pointList to elements of resultSet.
319         final double[] minDistSquared = new double[numPoints];
320 
321         // Initialize the elements.  Since the only point in resultSet is firstPoint,
322         // this is very easy.
323         for (int i = 0; i < numPoints; i++) {
324             if (i != firstPointIndex) { // That point isn't considered
325                 double d = distance(firstPoint, pointList.get(i));
326                 minDistSquared[i] = d*d;
327             }
328         }
329 
330         while (resultSet.size() < k) {
331 
332             // Sum up the squared distances for the points in pointList not
333             // already taken.
334             double distSqSum = 0.0;
335 
336             for (int i = 0; i < numPoints; i++) {
337                 if (!taken[i]) {
338                     distSqSum += minDistSquared[i];
339                 }
340             }
341 
342             // Add one new data point as a center. Each point x is chosen with
343             // probability proportional to D(x)2
344             final double r = random.nextDouble() * distSqSum;
345 
346             // The index of the next point to be added to the resultSet.
347             int nextPointIndex = -1;
348 
349             // Sum through the squared min distances again, stopping when
350             // sum >= r.
351             double sum = 0.0;
352             for (int i = 0; i < numPoints; i++) {
353                 if (!taken[i]) {
354                     sum += minDistSquared[i];
355                     if (sum >= r) {
356                         nextPointIndex = i;
357                         break;
358                     }
359                 }
360             }
361 
362             // If it's not set to >= 0, the point wasn't found in the previous
363             // for loop, probably because distances are extremely small.  Just pick
364             // the last available point.
365             if (nextPointIndex == -1) {
366                 for (int i = numPoints - 1; i >= 0; i--) {
367                     if (!taken[i]) {
368                         nextPointIndex = i;
369                         break;
370                     }
371                 }
372             }
373 
374             // We found one.
375             if (nextPointIndex >= 0) {
376 
377                 final T p = pointList.get(nextPointIndex);
378 
379                 resultSet.add(new CentroidCluster<T> (p));
380 
381                 // Mark it as taken.
382                 taken[nextPointIndex] = true;
383 
384                 if (resultSet.size() < k) {
385                     // Now update elements of minDistSquared.  We only have to compute
386                     // the distance to the new center to do this.
387                     for (int j = 0; j < numPoints; j++) {
388                         // Only have to worry about the points still not taken.
389                         if (!taken[j]) {
390                             double d = distance(p, pointList.get(j));
391                             double d2 = d * d;
392                             if (d2 < minDistSquared[j]) {
393                                 minDistSquared[j] = d2;
394                             }
395                         }
396                     }
397                 }
398 
399             } else {
400                 // None found --
401                 // Break from the while loop to prevent
402                 // an infinite loop.
403                 break;
404             }
405         }
406 
407         return resultSet;
408     }
409 
410     /**
411      * Get a random point from the {@link Cluster} with the largest distance variance.
412      *
413      * @param clusters the {@link Cluster}s to search
414      * @return a random point from the selected cluster
415      * @throws MathIllegalStateException if clusters are all empty
416      */
417     private T getPointFromLargestVarianceCluster(final Collection<CentroidCluster<T>> clusters)
418             throws MathIllegalStateException {
419 
420         double maxVariance = Double.NEGATIVE_INFINITY;
421         Cluster<T> selected = null;
422         for (final CentroidCluster<T> cluster : clusters) {
423             if (!cluster.getPoints().isEmpty()) {
424 
425                 // compute the distance variance of the current cluster
426                 final Clusterable center = cluster.getCenter();
427                 final Variance stat = new Variance();
428                 for (final T point : cluster.getPoints()) {
429                     stat.increment(distance(point, center));
430                 }
431                 final double variance = stat.getResult();
432 
433                 // select the cluster with the largest variance
434                 if (variance > maxVariance) {
435                     maxVariance = variance;
436                     selected = cluster;
437                 }
438 
439             }
440         }
441 
442         // did we find at least one non-empty cluster ?
443         if (selected == null) {
444             throw new MathIllegalStateException(LocalizedClusteringFormats.EMPTY_CLUSTER_IN_K_MEANS);
445         }
446 
447         // extract a random point from the cluster
448         final List<T> selectedPoints = selected.getPoints();
449         return selectedPoints.remove(random.nextInt(selectedPoints.size()));
450 
451     }
452 
453     /**
454      * Get a random point from the {@link Cluster} with the largest number of points
455      *
456      * @param clusters the {@link Cluster}s to search
457      * @return a random point from the selected cluster
458      * @throws MathIllegalStateException if clusters are all empty
459      */
460     private T getPointFromLargestNumberCluster(final Collection<? extends Cluster<T>> clusters)
461             throws MathIllegalStateException {
462 
463         int maxNumber = 0;
464         Cluster<T> selected = null;
465         for (final Cluster<T> cluster : clusters) {
466 
467             // get the number of points of the current cluster
468             final int number = cluster.getPoints().size();
469 
470             // select the cluster with the largest number of points
471             if (number > maxNumber) {
472                 maxNumber = number;
473                 selected = cluster;
474             }
475 
476         }
477 
478         // did we find at least one non-empty cluster ?
479         if (selected == null) {
480             throw new MathIllegalStateException(LocalizedClusteringFormats.EMPTY_CLUSTER_IN_K_MEANS);
481         }
482 
483         // extract a random point from the cluster
484         final List<T> selectedPoints = selected.getPoints();
485         return selectedPoints.remove(random.nextInt(selectedPoints.size()));
486 
487     }
488 
489     /**
490      * Get the point farthest to its cluster center
491      *
492      * @param clusters the {@link Cluster}s to search
493      * @return point farthest to its cluster center
494      * @throws MathIllegalStateException if clusters are all empty
495      */
496     private T getFarthestPoint(final Collection<CentroidCluster<T>> clusters) throws MathIllegalStateException {
497 
498         double maxDistance = Double.NEGATIVE_INFINITY;
499         Cluster<T> selectedCluster = null;
500         int selectedPoint = -1;
501         for (final CentroidCluster<T> cluster : clusters) {
502 
503             // get the farthest point
504             final Clusterable center = cluster.getCenter();
505             final List<T> points = cluster.getPoints();
506             for (int i = 0; i < points.size(); ++i) {
507                 final double distance = distance(points.get(i), center);
508                 if (distance > maxDistance) {
509                     maxDistance     = distance;
510                     selectedCluster = cluster;
511                     selectedPoint   = i;
512                 }
513             }
514 
515         }
516 
517         // did we find at least one non-empty cluster ?
518         if (selectedCluster == null) {
519             throw new MathIllegalStateException(LocalizedClusteringFormats.EMPTY_CLUSTER_IN_K_MEANS);
520         }
521 
522         return selectedCluster.getPoints().remove(selectedPoint);
523 
524     }
525 
526     /**
527      * Returns the nearest {@link Cluster} to the given point
528      *
529      * @param clusters the {@link Cluster}s to search
530      * @param point the point to find the nearest {@link Cluster} for
531      * @return the index of the nearest {@link Cluster} to the given point
532      */
533     private int getNearestCluster(final Collection<CentroidCluster<T>> clusters, final T point) {
534         double minDistance = Double.MAX_VALUE;
535         int clusterIndex = 0;
536         int minCluster = 0;
537         for (final CentroidCluster<T> c : clusters) {
538             final double distance = distance(point, c.getCenter());
539             if (distance < minDistance) {
540                 minDistance = distance;
541                 minCluster = clusterIndex;
542             }
543             clusterIndex++;
544         }
545         return minCluster;
546     }
547 
548     /**
549      * Computes the centroid for a set of points.
550      *
551      * @param points the set of points
552      * @param dimension the point dimension
553      * @return the computed centroid for the set of points
554      */
555     private Clusterable centroidOf(final Collection<T> points, final int dimension) {
556         final double[] centroid = new double[dimension];
557         for (final T p : points) {
558             final double[] point = p.getPoint();
559             for (int i = 0; i < centroid.length; i++) {
560                 centroid[i] += point[i];
561             }
562         }
563         for (int i = 0; i < centroid.length; i++) {
564             centroid[i] /= points.size();
565         }
566         return new DoublePoint(centroid);
567     }
568 
569 }