KMeansPlusPlusClusterer.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      https://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. /*
  18.  * This is not the original file distributed by the Apache Software Foundation
  19.  * It has been modified by the Hipparchus project
  20.  */

  21. package org.hipparchus.clustering;

  22. import java.util.ArrayList;
  23. import java.util.Collection;
  24. import java.util.Collections;
  25. import java.util.List;

  26. import org.hipparchus.clustering.distance.DistanceMeasure;
  27. import org.hipparchus.clustering.distance.EuclideanDistance;
  28. import org.hipparchus.exception.LocalizedCoreFormats;
  29. import org.hipparchus.exception.MathIllegalArgumentException;
  30. import org.hipparchus.exception.MathIllegalStateException;
  31. import org.hipparchus.random.JDKRandomGenerator;
  32. import org.hipparchus.random.RandomGenerator;
  33. import org.hipparchus.stat.descriptive.moment.Variance;
  34. import org.hipparchus.util.MathUtils;

  35. /**
  36.  * Clustering algorithm based on the k-means++ algorithm of David Arthur and Sergei Vassilvitskii.
  37.  * @param <T> type of the points to cluster
  38.  * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
  39.  */
  40. public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T> {

  41.     /** Strategies to use for replacing an empty cluster. */
  42.     public enum EmptyClusterStrategy {

  43.         /** Split the cluster with largest distance variance. */
  44.         LARGEST_VARIANCE,

  45.         /** Split the cluster with largest number of points. */
  46.         LARGEST_POINTS_NUMBER,

  47.         /** Create a cluster around the point farthest from its centroid. */
  48.         FARTHEST_POINT,

  49.         /** Generate an error. */
  50.         ERROR

  51.     }

  52.     /** The number of clusters. */
  53.     private final int k;

  54.     /** The maximum number of iterations. */
  55.     private final int maxIterations;

  56.     /** Random generator for choosing initial centers. */
  57.     private final RandomGenerator random;

  58.     /** Selected strategy for empty clusters. */
  59.     private final EmptyClusterStrategy emptyStrategy;

  60.     /** Build a clusterer.
  61.      * <p>
  62.      * The default strategy for handling empty clusters that may appear during
  63.      * algorithm iterations is to split the cluster with largest distance variance.
  64.      * <p>
  65.      * The euclidean distance will be used as default distance measure.
  66.      *
  67.      * @param k the number of clusters to split the data into
  68.      */
  69.     public KMeansPlusPlusClusterer(final int k) {
  70.         this(k, -1);
  71.     }

  72.     /** Build a clusterer.
  73.      * <p>
  74.      * The default strategy for handling empty clusters that may appear during
  75.      * algorithm iterations is to split the cluster with largest distance variance.
  76.      * <p>
  77.      * The euclidean distance will be used as default distance measure.
  78.      *
  79.      * @param k the number of clusters to split the data into
  80.      * @param maxIterations the maximum number of iterations to run the algorithm for.
  81.      *   If negative, no maximum will be used.
  82.      */
  83.     public KMeansPlusPlusClusterer(final int k, final int maxIterations) {
  84.         this(k, maxIterations, new EuclideanDistance());
  85.     }

  86.     /** Build a clusterer.
  87.      * <p>
  88.      * The default strategy for handling empty clusters that may appear during
  89.      * algorithm iterations is to split the cluster with largest distance variance.
  90.      *
  91.      * @param k the number of clusters to split the data into
  92.      * @param maxIterations the maximum number of iterations to run the algorithm for.
  93.      *   If negative, no maximum will be used.
  94.      * @param measure the distance measure to use
  95.      */
  96.     public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure) {
  97.         this(k, maxIterations, measure, new JDKRandomGenerator());
  98.     }

  99.     /** Build a clusterer.
  100.      * <p>
  101.      * The default strategy for handling empty clusters that may appear during
  102.      * algorithm iterations is to split the cluster with largest distance variance.
  103.      *
  104.      * @param k the number of clusters to split the data into
  105.      * @param maxIterations the maximum number of iterations to run the algorithm for.
  106.      *   If negative, no maximum will be used.
  107.      * @param measure the distance measure to use
  108.      * @param random random generator to use for choosing initial centers
  109.      */
  110.     public KMeansPlusPlusClusterer(final int k, final int maxIterations,
  111.                                    final DistanceMeasure measure,
  112.                                    final RandomGenerator random) {
  113.         this(k, maxIterations, measure, random, EmptyClusterStrategy.LARGEST_VARIANCE);
  114.     }

  115.     /** Build a clusterer.
  116.      *
  117.      * @param k the number of clusters to split the data into
  118.      * @param maxIterations the maximum number of iterations to run the algorithm for.
  119.      *   If negative, no maximum will be used.
  120.      * @param measure the distance measure to use
  121.      * @param random random generator to use for choosing initial centers
  122.      * @param emptyStrategy strategy to use for handling empty clusters that
  123.      * may appear during algorithm iterations
  124.      */
  125.     public KMeansPlusPlusClusterer(final int k, final int maxIterations,
  126.                                    final DistanceMeasure measure,
  127.                                    final RandomGenerator random,
  128.                                    final EmptyClusterStrategy emptyStrategy) {
  129.         super(measure);
  130.         this.k             = k;
  131.         this.maxIterations = maxIterations;
  132.         this.random        = random;
  133.         this.emptyStrategy = emptyStrategy;
  134.     }

  135.     /**
  136.      * Return the number of clusters this instance will use.
  137.      * @return the number of clusters
  138.      */
  139.     public int getK() {
  140.         return k;
  141.     }

  142.     /**
  143.      * Returns the maximum number of iterations this instance will use.
  144.      * @return the maximum number of iterations, or -1 if no maximum is set
  145.      */
  146.     public int getMaxIterations() {
  147.         return maxIterations;
  148.     }

  149.     /**
  150.      * Returns the random generator this instance will use.
  151.      * @return the random generator
  152.      */
  153.     public RandomGenerator getRandomGenerator() {
  154.         return random;
  155.     }

  156.     /**
  157.      * Returns the {@link EmptyClusterStrategy} used by this instance.
  158.      * @return the {@link EmptyClusterStrategy}
  159.      */
  160.     public EmptyClusterStrategy getEmptyClusterStrategy() {
  161.         return emptyStrategy;
  162.     }

  163.     /**
  164.      * Runs the K-means++ clustering algorithm.
  165.      *
  166.      * @param points the points to cluster
  167.      * @return a list of clusters containing the points
  168.      * @throws MathIllegalArgumentException if the data points are null or the number
  169.      *     of clusters is larger than the number of data points
  170.      * @throws MathIllegalStateException if an empty cluster is encountered and the
  171.      * {@link #emptyStrategy} is set to {@code ERROR}
  172.      */
  173.     @Override
  174.     public List<CentroidCluster<T>> cluster(final Collection<T> points)
  175.         throws MathIllegalArgumentException, MathIllegalStateException {

  176.         // sanity checks
  177.         MathUtils.checkNotNull(points);

  178.         // number of clusters has to be smaller or equal the number of data points
  179.         if (points.size() < k) {
  180.             throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL_BOUND_EXCLUDED,
  181.                                                    points.size(), k);
  182.         }

  183.         // create the initial clusters
  184.         List<CentroidCluster<T>> clusters = chooseInitialCenters(points);

  185.         // create an array containing the latest assignment of a point to a cluster
  186.         // no need to initialize the array, as it will be filled with the first assignment
  187.         int[] assignments = new int[points.size()];
  188.         assignPointsToClusters(clusters, points, assignments);

  189.         // iterate through updating the centers until we're done
  190.         final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
  191.         for (int count = 0; count < max; count++) {
  192.             boolean emptyCluster = false;
  193.             List<CentroidCluster<T>> newClusters = new ArrayList<>();
  194.             for (final CentroidCluster<T> cluster : clusters) {
  195.                 final Clusterable newCenter;
  196.                 if (cluster.getPoints().isEmpty()) {
  197.                     switch (emptyStrategy) {
  198.                         case LARGEST_VARIANCE :
  199.                             newCenter = getPointFromLargestVarianceCluster(clusters);
  200.                             break;
  201.                         case LARGEST_POINTS_NUMBER :
  202.                             newCenter = getPointFromLargestNumberCluster(clusters);
  203.                             break;
  204.                         case FARTHEST_POINT :
  205.                             newCenter = getFarthestPoint(clusters);
  206.                             break;
  207.                         default :
  208.                             throw new MathIllegalStateException(LocalizedClusteringFormats.EMPTY_CLUSTER_IN_K_MEANS);
  209.                     }
  210.                     emptyCluster = true;
  211.                 } else {
  212.                     newCenter = centroidOf(cluster.getPoints(), cluster.getCenter().getPoint().length);
  213.                 }
  214.                 newClusters.add(new CentroidCluster<>(newCenter));
  215.             }
  216.             int changes = assignPointsToClusters(newClusters, points, assignments);
  217.             clusters = newClusters;

  218.             // if there were no more changes in the point-to-cluster assignment
  219.             // and there are no empty clusters left, return the current clusters
  220.             if (changes == 0 && !emptyCluster) {
  221.                 return clusters;
  222.             }
  223.         }
  224.         return clusters;
  225.     }

  226.     /**
  227.      * Adds the given points to the closest {@link Cluster}.
  228.      *
  229.      * @param clusters the {@link Cluster}s to add the points to
  230.      * @param points the points to add to the given {@link Cluster}s
  231.      * @param assignments points assignments to clusters
  232.      * @return the number of points assigned to different clusters as the iteration before
  233.      */
  234.     private int assignPointsToClusters(final List<CentroidCluster<T>> clusters,
  235.                                        final Collection<T> points,
  236.                                        final int[] assignments) {
  237.         int assignedDifferently = 0;
  238.         int pointIndex = 0;
  239.         for (final T p : points) {
  240.             int clusterIndex = getNearestCluster(clusters, p);
  241.             if (clusterIndex != assignments[pointIndex]) {
  242.                 assignedDifferently++;
  243.             }

  244.             CentroidCluster<T> cluster = clusters.get(clusterIndex);
  245.             cluster.addPoint(p);
  246.             assignments[pointIndex++] = clusterIndex;
  247.         }

  248.         return assignedDifferently;
  249.     }

  250.     /**
  251.      * Use K-means++ to choose the initial centers.
  252.      *
  253.      * @param points the points to choose the initial centers from
  254.      * @return the initial centers
  255.      */
  256.     private List<CentroidCluster<T>> chooseInitialCenters(final Collection<T> points) {

  257.         // Convert to list for indexed access. Make it unmodifiable, since removal of items
  258.         // would screw up the logic of this method.
  259.         final List<T> pointList = Collections.unmodifiableList(new ArrayList<>(points));

  260.         // The number of points in the list.
  261.         final int numPoints = pointList.size();

  262.         // Set the corresponding element in this array to indicate when
  263.         // elements of pointList are no longer available.
  264.         final boolean[] taken = new boolean[numPoints];

  265.         // The resulting list of initial centers.
  266.         final List<CentroidCluster<T>> resultSet = new ArrayList<>();

  267.         // Choose one center uniformly at random from among the data points.
  268.         final int firstPointIndex = random.nextInt(numPoints);

  269.         final T firstPoint = pointList.get(firstPointIndex);

  270.         resultSet.add(new CentroidCluster<>(firstPoint));

  271.         // Must mark it as taken
  272.         taken[firstPointIndex] = true;

  273.         // To keep track of the minimum distance squared of elements of
  274.         // pointList to elements of resultSet.
  275.         final double[] minDistSquared = new double[numPoints];

  276.         // Initialize the elements.  Since the only point in resultSet is firstPoint,
  277.         // this is very easy.
  278.         for (int i = 0; i < numPoints; i++) {
  279.             if (i != firstPointIndex) { // That point isn't considered
  280.                 double d = distance(firstPoint, pointList.get(i));
  281.                 minDistSquared[i] = d*d;
  282.             }
  283.         }

  284.         while (resultSet.size() < k) {

  285.             // Sum up the squared distances for the points in pointList not
  286.             // already taken.
  287.             double distSqSum = 0.0;

  288.             for (int i = 0; i < numPoints; i++) {
  289.                 if (!taken[i]) {
  290.                     distSqSum += minDistSquared[i];
  291.                 }
  292.             }

  293.             // Add one new data point as a center. Each point x is chosen with
  294.             // probability proportional to D(x)2
  295.             final double r = random.nextDouble() * distSqSum;

  296.             // The index of the next point to be added to the resultSet.
  297.             int nextPointIndex = -1;

  298.             // Sum through the squared min distances again, stopping when
  299.             // sum >= r.
  300.             double sum = 0.0;
  301.             for (int i = 0; i < numPoints; i++) {
  302.                 if (!taken[i]) {
  303.                     sum += minDistSquared[i];
  304.                     if (sum >= r) {
  305.                         nextPointIndex = i;
  306.                         break;
  307.                     }
  308.                 }
  309.             }

  310.             // If it's not set to >= 0, the point wasn't found in the previous
  311.             // for loop, probably because distances are extremely small.  Just pick
  312.             // the last available point.
  313.             if (nextPointIndex == -1) {
  314.                 for (int i = numPoints - 1; i >= 0; i--) {
  315.                     if (!taken[i]) {
  316.                         nextPointIndex = i;
  317.                         break;
  318.                     }
  319.                 }
  320.             }

  321.             // We found one.
  322.             if (nextPointIndex >= 0) {

  323.                 final T p = pointList.get(nextPointIndex);

  324.                 resultSet.add(new CentroidCluster<>(p));

  325.                 // Mark it as taken.
  326.                 taken[nextPointIndex] = true;

  327.                 if (resultSet.size() < k) {
  328.                     // Now update elements of minDistSquared.  We only have to compute
  329.                     // the distance to the new center to do this.
  330.                     for (int j = 0; j < numPoints; j++) {
  331.                         // Only have to worry about the points still not taken.
  332.                         if (!taken[j]) {
  333.                             double d = distance(p, pointList.get(j));
  334.                             double d2 = d * d;
  335.                             if (d2 < minDistSquared[j]) {
  336.                                 minDistSquared[j] = d2;
  337.                             }
  338.                         }
  339.                     }
  340.                 }

  341.             } else {
  342.                 // None found --
  343.                 // Break from the while loop to prevent
  344.                 // an infinite loop.
  345.                 break;
  346.             }
  347.         }

  348.         return resultSet;
  349.     }

  350.     /**
  351.      * Get a random point from the {@link Cluster} with the largest distance variance.
  352.      *
  353.      * @param clusters the {@link Cluster}s to search
  354.      * @return a random point from the selected cluster
  355.      * @throws MathIllegalStateException if clusters are all empty
  356.      */
  357.     private T getPointFromLargestVarianceCluster(final Collection<CentroidCluster<T>> clusters)
  358.             throws MathIllegalStateException {

  359.         double maxVariance = Double.NEGATIVE_INFINITY;
  360.         Cluster<T> selected = null;
  361.         for (final CentroidCluster<T> cluster : clusters) {
  362.             if (!cluster.getPoints().isEmpty()) {

  363.                 // compute the distance variance of the current cluster
  364.                 final Clusterable center = cluster.getCenter();
  365.                 final Variance stat = new Variance();
  366.                 for (final T point : cluster.getPoints()) {
  367.                     stat.increment(distance(point, center));
  368.                 }
  369.                 final double variance = stat.getResult();

  370.                 // select the cluster with the largest variance
  371.                 if (variance > maxVariance) {
  372.                     maxVariance = variance;
  373.                     selected = cluster;
  374.                 }

  375.             }
  376.         }

  377.         // did we find at least one non-empty cluster ?
  378.         if (selected == null) {
  379.             throw new MathIllegalStateException(LocalizedClusteringFormats.EMPTY_CLUSTER_IN_K_MEANS);
  380.         }

  381.         // extract a random point from the cluster
  382.         final List<T> selectedPoints = selected.getPoints();
  383.         return selectedPoints.remove(random.nextInt(selectedPoints.size()));

  384.     }

  385.     /**
  386.      * Get a random point from the {@link Cluster} with the largest number of points
  387.      *
  388.      * @param clusters the {@link Cluster}s to search
  389.      * @return a random point from the selected cluster
  390.      * @throws MathIllegalStateException if clusters are all empty
  391.      */
  392.     private T getPointFromLargestNumberCluster(final Collection<? extends Cluster<T>> clusters)
  393.             throws MathIllegalStateException {

  394.         int maxNumber = 0;
  395.         Cluster<T> selected = null;
  396.         for (final Cluster<T> cluster : clusters) {

  397.             // get the number of points of the current cluster
  398.             final int number = cluster.getPoints().size();

  399.             // select the cluster with the largest number of points
  400.             if (number > maxNumber) {
  401.                 maxNumber = number;
  402.                 selected = cluster;
  403.             }

  404.         }

  405.         // did we find at least one non-empty cluster ?
  406.         if (selected == null) {
  407.             throw new MathIllegalStateException(LocalizedClusteringFormats.EMPTY_CLUSTER_IN_K_MEANS);
  408.         }

  409.         // extract a random point from the cluster
  410.         final List<T> selectedPoints = selected.getPoints();
  411.         return selectedPoints.remove(random.nextInt(selectedPoints.size()));

  412.     }

  413.     /**
  414.      * Get the point farthest to its cluster center
  415.      *
  416.      * @param clusters the {@link Cluster}s to search
  417.      * @return point farthest to its cluster center
  418.      * @throws MathIllegalStateException if clusters are all empty
  419.      */
  420.     private T getFarthestPoint(final Collection<CentroidCluster<T>> clusters) throws MathIllegalStateException {

  421.         double maxDistance = Double.NEGATIVE_INFINITY;
  422.         Cluster<T> selectedCluster = null;
  423.         int selectedPoint = -1;
  424.         for (final CentroidCluster<T> cluster : clusters) {

  425.             // get the farthest point
  426.             final Clusterable center = cluster.getCenter();
  427.             final List<T> points = cluster.getPoints();
  428.             for (int i = 0; i < points.size(); ++i) {
  429.                 final double distance = distance(points.get(i), center);
  430.                 if (distance > maxDistance) {
  431.                     maxDistance     = distance;
  432.                     selectedCluster = cluster;
  433.                     selectedPoint   = i;
  434.                 }
  435.             }

  436.         }

  437.         // did we find at least one non-empty cluster ?
  438.         if (selectedCluster == null) {
  439.             throw new MathIllegalStateException(LocalizedClusteringFormats.EMPTY_CLUSTER_IN_K_MEANS);
  440.         }

  441.         return selectedCluster.getPoints().remove(selectedPoint);

  442.     }

  443.     /**
  444.      * Returns the nearest {@link Cluster} to the given point
  445.      *
  446.      * @param clusters the {@link Cluster}s to search
  447.      * @param point the point to find the nearest {@link Cluster} for
  448.      * @return the index of the nearest {@link Cluster} to the given point
  449.      */
  450.     private int getNearestCluster(final Collection<CentroidCluster<T>> clusters, final T point) {
  451.         double minDistance = Double.MAX_VALUE;
  452.         int clusterIndex = 0;
  453.         int minCluster = 0;
  454.         for (final CentroidCluster<T> c : clusters) {
  455.             final double distance = distance(point, c.getCenter());
  456.             if (distance < minDistance) {
  457.                 minDistance = distance;
  458.                 minCluster = clusterIndex;
  459.             }
  460.             clusterIndex++;
  461.         }
  462.         return minCluster;
  463.     }

  464.     /**
  465.      * Computes the centroid for a set of points.
  466.      *
  467.      * @param points the set of points
  468.      * @param dimension the point dimension
  469.      * @return the computed centroid for the set of points
  470.      */
  471.     private Clusterable centroidOf(final Collection<T> points, final int dimension) {
  472.         final double[] centroid = new double[dimension];
  473.         for (final T p : points) {
  474.             final double[] point = p.getPoint();
  475.             for (int i = 0; i < centroid.length; i++) {
  476.                 centroid[i] += point[i];
  477.             }
  478.         }
  479.         for (int i = 0; i < centroid.length; i++) {
  480.             centroid[i] /= points.size();
  481.         }
  482.         return new DoublePoint(centroid);
  483.     }

  484. }