KMeansPlusPlusClusterer.java
- /*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- /*
- * This is not the original file distributed by the Apache Software Foundation
- * It has been modified by the Hipparchus project
- */
- package org.hipparchus.clustering;
- import java.util.ArrayList;
- import java.util.Collection;
- import java.util.Collections;
- import java.util.List;
- import org.hipparchus.clustering.distance.DistanceMeasure;
- import org.hipparchus.clustering.distance.EuclideanDistance;
- import org.hipparchus.exception.LocalizedCoreFormats;
- import org.hipparchus.exception.MathIllegalArgumentException;
- import org.hipparchus.exception.MathIllegalStateException;
- import org.hipparchus.random.JDKRandomGenerator;
- import org.hipparchus.random.RandomGenerator;
- import org.hipparchus.stat.descriptive.moment.Variance;
- import org.hipparchus.util.MathUtils;
- /**
- * Clustering algorithm based on David Arthur and Sergei Vassilvitskii's k-means++ algorithm.
- * @param <T> type of the points to cluster
- * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
- */
- public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T> {
/**
 * Policies available for re-seeding a cluster that ended up without any
 * point during an iteration of the algorithm.
 */
public enum EmptyClusterStrategy {

    /** Take a point out of the cluster whose point-to-center distances have the largest variance. */
    LARGEST_VARIANCE,

    /** Take a point out of the cluster that currently holds the most points. */
    LARGEST_POINTS_NUMBER,

    /** Re-seed with the point that lies farthest from the center of its own cluster. */
    FARTHEST_POINT,

    /** Do not repair anything: report the empty cluster as an error. */
    ERROR

}
- /** The number of clusters. */
- private final int k;
- /** The maximum number of iterations. */
- private final int maxIterations;
- /** Random generator for choosing initial centers. */
- private final RandomGenerator random;
- /** Selected strategy for empty clusters. */
- private final EmptyClusterStrategy emptyStrategy;
/** Build a clusterer with no limit on the number of iterations.
 * <p>
 * Empty clusters appearing during the iterations are handled by splitting
 * the cluster with the largest distance variance.
 * <p>
 * Distances are measured with the euclidean distance.
 *
 * @param k the number of clusters to split the data into
 */
public KMeansPlusPlusClusterer(final int k) {
    // a negative iteration count stands for "no maximum"
    this(k, -1);
}
/** Build a clusterer with a bounded number of iterations.
 * <p>
 * Empty clusters appearing during the iterations are handled by splitting
 * the cluster with the largest distance variance.
 * <p>
 * Distances are measured with the euclidean distance.
 *
 * @param k the number of clusters to split the data into
 * @param maxIterations the maximum number of iterations to run the algorithm for.
 *   If negative, no maximum will be used.
 */
public KMeansPlusPlusClusterer(final int k, final int maxIterations) {
    // default distance measure: euclidean
    this(k, maxIterations, new EuclideanDistance());
}
/** Build a clusterer with a caller-supplied distance measure.
 * <p>
 * Empty clusters appearing during the iterations are handled by splitting
 * the cluster with the largest distance variance. A fresh
 * {@link JDKRandomGenerator} is used as the random generator.
 *
 * @param k the number of clusters to split the data into
 * @param maxIterations the maximum number of iterations to run the algorithm for.
 *   If negative, no maximum will be used.
 * @param measure the distance measure to use
 */
public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure) {
    // default random source: a new JDK-backed generator
    this(k, maxIterations, measure, new JDKRandomGenerator());
}
/** Build a clusterer with caller-supplied distance measure and random generator.
 * <p>
 * Empty clusters appearing during the iterations are handled by splitting
 * the cluster with the largest distance variance.
 *
 * @param k the number of clusters to split the data into
 * @param maxIterations the maximum number of iterations to run the algorithm for.
 *   If negative, no maximum will be used.
 * @param measure the distance measure to use
 * @param random random generator to use for choosing initial centers
 */
public KMeansPlusPlusClusterer(final int k,
                               final int maxIterations,
                               final DistanceMeasure measure,
                               final RandomGenerator random) {
    // default empty-cluster policy: split the cluster with the largest variance
    this(k, maxIterations, measure, random, EmptyClusterStrategy.LARGEST_VARIANCE);
}
/** Build a fully configured clusterer.
 * <p>
 * All other constructors of this class end up delegating to this one.
 *
 * @param k the number of clusters to split the data into
 * @param maxIterations the maximum number of iterations to run the algorithm for.
 *   If negative, no maximum will be used.
 * @param measure the distance measure to use
 * @param random random generator to use for choosing initial centers
 * @param emptyStrategy strategy to use for handling empty clusters that
 *   may appear during algorithm iterations
 */
public KMeansPlusPlusClusterer(final int k,
                               final int maxIterations,
                               final DistanceMeasure measure,
                               final RandomGenerator random,
                               final EmptyClusterStrategy emptyStrategy) {
    // the distance measure is stored by the parent class
    super(measure);

    // the arguments are stored as given
    this.emptyStrategy = emptyStrategy;
    this.random        = random;
    this.maxIterations = maxIterations;
    this.k             = k;
}
- /**
- * Return the number of clusters this instance will use.
- * @return the number of clusters
- */
- public int getK() {
- return k;
- }
- /**
- * Returns the maximum number of iterations this instance will use.
- * @return the maximum number of iterations, or -1 if no maximum is set
- */
- public int getMaxIterations() {
- return maxIterations;
- }
- /**
- * Returns the random generator this instance will use.
- * @return the random generator
- */
- public RandomGenerator getRandomGenerator() {
- return random;
- }
- /**
- * Returns the {@link EmptyClusterStrategy} used by this instance.
- * @return the {@link EmptyClusterStrategy}
- */
- public EmptyClusterStrategy getEmptyClusterStrategy() {
- return emptyStrategy;
- }
- /**
- * Runs the K-means++ clustering algorithm.
- *
- * @param points the points to cluster
- * @return a list of clusters containing the points
- * @throws MathIllegalArgumentException if the data points are null or the number
- * of clusters is larger than the number of data points
- * @throws MathIllegalStateException if an empty cluster is encountered and the
- * {@link #emptyStrategy} is set to {@code ERROR}
- */
- @Override
- public List<CentroidCluster<T>> cluster(final Collection<T> points)
- throws MathIllegalArgumentException, MathIllegalStateException {
- // sanity checks
- MathUtils.checkNotNull(points);
- // number of clusters has to be smaller or equal the number of data points
- if (points.size() < k) {
- throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL_BOUND_EXCLUDED,
- points.size(), k);
- }
- // create the initial clusters
- List<CentroidCluster<T>> clusters = chooseInitialCenters(points);
- // create an array containing the latest assignment of a point to a cluster
- // no need to initialize the array, as it will be filled with the first assignment
- int[] assignments = new int[points.size()];
- assignPointsToClusters(clusters, points, assignments);
- // iterate through updating the centers until we're done
- final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
- for (int count = 0; count < max; count++) {
- boolean emptyCluster = false;
- List<CentroidCluster<T>> newClusters = new ArrayList<>();
- for (final CentroidCluster<T> cluster : clusters) {
- final Clusterable newCenter;
- if (cluster.getPoints().isEmpty()) {
- switch (emptyStrategy) {
- case LARGEST_VARIANCE :
- newCenter = getPointFromLargestVarianceCluster(clusters);
- break;
- case LARGEST_POINTS_NUMBER :
- newCenter = getPointFromLargestNumberCluster(clusters);
- break;
- case FARTHEST_POINT :
- newCenter = getFarthestPoint(clusters);
- break;
- default :
- throw new MathIllegalStateException(LocalizedClusteringFormats.EMPTY_CLUSTER_IN_K_MEANS);
- }
- emptyCluster = true;
- } else {
- newCenter = centroidOf(cluster.getPoints(), cluster.getCenter().getPoint().length);
- }
- newClusters.add(new CentroidCluster<>(newCenter));
- }
- int changes = assignPointsToClusters(newClusters, points, assignments);
- clusters = newClusters;
- // if there were no more changes in the point-to-cluster assignment
- // and there are no empty clusters left, return the current clusters
- if (changes == 0 && !emptyCluster) {
- return clusters;
- }
- }
- return clusters;
- }
- /**
- * Adds the given points to the closest {@link Cluster}.
- *
- * @param clusters the {@link Cluster}s to add the points to
- * @param points the points to add to the given {@link Cluster}s
- * @param assignments points assignments to clusters
- * @return the number of points assigned to different clusters as the iteration before
- */
- private int assignPointsToClusters(final List<CentroidCluster<T>> clusters,
- final Collection<T> points,
- final int[] assignments) {
- int assignedDifferently = 0;
- int pointIndex = 0;
- for (final T p : points) {
- int clusterIndex = getNearestCluster(clusters, p);
- if (clusterIndex != assignments[pointIndex]) {
- assignedDifferently++;
- }
- CentroidCluster<T> cluster = clusters.get(clusterIndex);
- cluster.addPoint(p);
- assignments[pointIndex++] = clusterIndex;
- }
- return assignedDifferently;
- }
- /**
- * Use K-means++ to choose the initial centers.
- *
- * @param points the points to choose the initial centers from
- * @return the initial centers
- */
- private List<CentroidCluster<T>> chooseInitialCenters(final Collection<T> points) {
- // Convert to list for indexed access. Make it unmodifiable, since removal of items
- // would screw up the logic of this method.
- final List<T> pointList = Collections.unmodifiableList(new ArrayList<>(points));
- // The number of points in the list.
- final int numPoints = pointList.size();
- // Set the corresponding element in this array to indicate when
- // elements of pointList are no longer available.
- final boolean[] taken = new boolean[numPoints];
- // The resulting list of initial centers.
- final List<CentroidCluster<T>> resultSet = new ArrayList<>();
- // Choose one center uniformly at random from among the data points.
- final int firstPointIndex = random.nextInt(numPoints);
- final T firstPoint = pointList.get(firstPointIndex);
- resultSet.add(new CentroidCluster<>(firstPoint));
- // Must mark it as taken
- taken[firstPointIndex] = true;
- // To keep track of the minimum distance squared of elements of
- // pointList to elements of resultSet.
- final double[] minDistSquared = new double[numPoints];
- // Initialize the elements. Since the only point in resultSet is firstPoint,
- // this is very easy.
- for (int i = 0; i < numPoints; i++) {
- if (i != firstPointIndex) { // That point isn't considered
- double d = distance(firstPoint, pointList.get(i));
- minDistSquared[i] = d*d;
- }
- }
- while (resultSet.size() < k) {
- // Sum up the squared distances for the points in pointList not
- // already taken.
- double distSqSum = 0.0;
- for (int i = 0; i < numPoints; i++) {
- if (!taken[i]) {
- distSqSum += minDistSquared[i];
- }
- }
- // Add one new data point as a center. Each point x is chosen with
- // probability proportional to D(x)2
- final double r = random.nextDouble() * distSqSum;
- // The index of the next point to be added to the resultSet.
- int nextPointIndex = -1;
- // Sum through the squared min distances again, stopping when
- // sum >= r.
- double sum = 0.0;
- for (int i = 0; i < numPoints; i++) {
- if (!taken[i]) {
- sum += minDistSquared[i];
- if (sum >= r) {
- nextPointIndex = i;
- break;
- }
- }
- }
- // If it's not set to >= 0, the point wasn't found in the previous
- // for loop, probably because distances are extremely small. Just pick
- // the last available point.
- if (nextPointIndex == -1) {
- for (int i = numPoints - 1; i >= 0; i--) {
- if (!taken[i]) {
- nextPointIndex = i;
- break;
- }
- }
- }
- // We found one.
- if (nextPointIndex >= 0) {
- final T p = pointList.get(nextPointIndex);
- resultSet.add(new CentroidCluster<>(p));
- // Mark it as taken.
- taken[nextPointIndex] = true;
- if (resultSet.size() < k) {
- // Now update elements of minDistSquared. We only have to compute
- // the distance to the new center to do this.
- for (int j = 0; j < numPoints; j++) {
- // Only have to worry about the points still not taken.
- if (!taken[j]) {
- double d = distance(p, pointList.get(j));
- double d2 = d * d;
- if (d2 < minDistSquared[j]) {
- minDistSquared[j] = d2;
- }
- }
- }
- }
- } else {
- // None found --
- // Break from the while loop to prevent
- // an infinite loop.
- break;
- }
- }
- return resultSet;
- }
- /**
- * Get a random point from the {@link Cluster} with the largest distance variance.
- *
- * @param clusters the {@link Cluster}s to search
- * @return a random point from the selected cluster
- * @throws MathIllegalStateException if clusters are all empty
- */
- private T getPointFromLargestVarianceCluster(final Collection<CentroidCluster<T>> clusters)
- throws MathIllegalStateException {
- double maxVariance = Double.NEGATIVE_INFINITY;
- Cluster<T> selected = null;
- for (final CentroidCluster<T> cluster : clusters) {
- if (!cluster.getPoints().isEmpty()) {
- // compute the distance variance of the current cluster
- final Clusterable center = cluster.getCenter();
- final Variance stat = new Variance();
- for (final T point : cluster.getPoints()) {
- stat.increment(distance(point, center));
- }
- final double variance = stat.getResult();
- // select the cluster with the largest variance
- if (variance > maxVariance) {
- maxVariance = variance;
- selected = cluster;
- }
- }
- }
- // did we find at least one non-empty cluster ?
- if (selected == null) {
- throw new MathIllegalStateException(LocalizedClusteringFormats.EMPTY_CLUSTER_IN_K_MEANS);
- }
- // extract a random point from the cluster
- final List<T> selectedPoints = selected.getPoints();
- return selectedPoints.remove(random.nextInt(selectedPoints.size()));
- }
- /**
- * Get a random point from the {@link Cluster} with the largest number of points
- *
- * @param clusters the {@link Cluster}s to search
- * @return a random point from the selected cluster
- * @throws MathIllegalStateException if clusters are all empty
- */
- private T getPointFromLargestNumberCluster(final Collection<? extends Cluster<T>> clusters)
- throws MathIllegalStateException {
- int maxNumber = 0;
- Cluster<T> selected = null;
- for (final Cluster<T> cluster : clusters) {
- // get the number of points of the current cluster
- final int number = cluster.getPoints().size();
- // select the cluster with the largest number of points
- if (number > maxNumber) {
- maxNumber = number;
- selected = cluster;
- }
- }
- // did we find at least one non-empty cluster ?
- if (selected == null) {
- throw new MathIllegalStateException(LocalizedClusteringFormats.EMPTY_CLUSTER_IN_K_MEANS);
- }
- // extract a random point from the cluster
- final List<T> selectedPoints = selected.getPoints();
- return selectedPoints.remove(random.nextInt(selectedPoints.size()));
- }
- /**
- * Get the point farthest to its cluster center
- *
- * @param clusters the {@link Cluster}s to search
- * @return point farthest to its cluster center
- * @throws MathIllegalStateException if clusters are all empty
- */
- private T getFarthestPoint(final Collection<CentroidCluster<T>> clusters) throws MathIllegalStateException {
- double maxDistance = Double.NEGATIVE_INFINITY;
- Cluster<T> selectedCluster = null;
- int selectedPoint = -1;
- for (final CentroidCluster<T> cluster : clusters) {
- // get the farthest point
- final Clusterable center = cluster.getCenter();
- final List<T> points = cluster.getPoints();
- for (int i = 0; i < points.size(); ++i) {
- final double distance = distance(points.get(i), center);
- if (distance > maxDistance) {
- maxDistance = distance;
- selectedCluster = cluster;
- selectedPoint = i;
- }
- }
- }
- // did we find at least one non-empty cluster ?
- if (selectedCluster == null) {
- throw new MathIllegalStateException(LocalizedClusteringFormats.EMPTY_CLUSTER_IN_K_MEANS);
- }
- return selectedCluster.getPoints().remove(selectedPoint);
- }
- /**
- * Returns the nearest {@link Cluster} to the given point
- *
- * @param clusters the {@link Cluster}s to search
- * @param point the point to find the nearest {@link Cluster} for
- * @return the index of the nearest {@link Cluster} to the given point
- */
- private int getNearestCluster(final Collection<CentroidCluster<T>> clusters, final T point) {
- double minDistance = Double.MAX_VALUE;
- int clusterIndex = 0;
- int minCluster = 0;
- for (final CentroidCluster<T> c : clusters) {
- final double distance = distance(point, c.getCenter());
- if (distance < minDistance) {
- minDistance = distance;
- minCluster = clusterIndex;
- }
- clusterIndex++;
- }
- return minCluster;
- }
- /**
- * Computes the centroid for a set of points.
- *
- * @param points the set of points
- * @param dimension the point dimension
- * @return the computed centroid for the set of points
- */
- private Clusterable centroidOf(final Collection<T> points, final int dimension) {
- final double[] centroid = new double[dimension];
- for (final T p : points) {
- final double[] point = p.getPoint();
- for (int i = 0; i < centroid.length; i++) {
- centroid[i] += point[i];
- }
- }
- for (int i = 0; i < centroid.length; i++) {
- centroid[i] /= points.size();
- }
- return new DoublePoint(centroid);
- }
- }