MultiKMeansPlusPlusClusterer.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      https://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. /*
  18.  * This is not the original file distributed by the Apache Software Foundation
  19.  * It has been modified by the Hipparchus project
  20.  */

  21. package org.hipparchus.clustering;

  22. import java.util.Collection;
  23. import java.util.List;

  24. import org.hipparchus.clustering.evaluation.ClusterEvaluator;
  25. import org.hipparchus.clustering.evaluation.SumOfClusterVariances;
  26. import org.hipparchus.exception.MathIllegalArgumentException;
  27. import org.hipparchus.exception.MathIllegalStateException;

  28. /**
  29.  * A wrapper around a k-means++ clustering algorithm which performs multiple trials
  30.  * and returns the best solution.
  31.  * @param <T> type of the points to cluster
  32.  */
  33. public class MultiKMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T> {

  34.     /** The underlying k-means clusterer. */
  35.     private final KMeansPlusPlusClusterer<T> clusterer;

  36.     /** The number of trial runs. */
  37.     private final int numTrials;

  38.     /** The cluster evaluator to use. */
  39.     private final ClusterEvaluator<T> evaluator;

  40.     /** Build a clusterer.
  41.      * @param clusterer the k-means clusterer to use
  42.      * @param numTrials number of trial runs
  43.      */
  44.     public MultiKMeansPlusPlusClusterer(final KMeansPlusPlusClusterer<T> clusterer,
  45.                                         final int numTrials) {
  46.         this(clusterer, numTrials, new SumOfClusterVariances<>(clusterer.getDistanceMeasure()));
  47.     }

  48.     /** Build a clusterer.
  49.      * @param clusterer the k-means clusterer to use
  50.      * @param numTrials number of trial runs
  51.      * @param evaluator the cluster evaluator to use
  52.      */
  53.     public MultiKMeansPlusPlusClusterer(final KMeansPlusPlusClusterer<T> clusterer,
  54.                                         final int numTrials,
  55.                                         final ClusterEvaluator<T> evaluator) {
  56.         super(clusterer.getDistanceMeasure());
  57.         this.clusterer = clusterer;
  58.         this.numTrials = numTrials;
  59.         this.evaluator = evaluator;
  60.     }

  61.     /**
  62.      * Returns the embedded k-means clusterer used by this instance.
  63.      * @return the embedded clusterer
  64.      */
  65.     public KMeansPlusPlusClusterer<T> getClusterer() {
  66.         return clusterer;
  67.     }

  68.     /**
  69.      * Returns the number of trials this instance will do.
  70.      * @return the number of trials
  71.      */
  72.     public int getNumTrials() {
  73.         return numTrials;
  74.     }

  75.     /**
  76.      * Returns the {@link ClusterEvaluator} used to determine the "best" clustering.
  77.      * @return the used {@link ClusterEvaluator}
  78.      */
  79.     public ClusterEvaluator<T> getClusterEvaluator() {
  80.        return evaluator;
  81.     }

  82.     /**
  83.      * Runs the K-means++ clustering algorithm.
  84.      *
  85.      * @param points the points to cluster
  86.      * @return a list of clusters containing the points
  87.      * @throws MathIllegalArgumentException if the data points are null or the number
  88.      *   of clusters is larger than the number of data points
  89.      * @throws MathIllegalStateException if an empty cluster is encountered and the
  90.      *   underlying {@link KMeansPlusPlusClusterer} has its
  91.      *   {@link KMeansPlusPlusClusterer.EmptyClusterStrategy} is set to {@code ERROR}.
  92.      */
  93.     @Override
  94.     public List<CentroidCluster<T>> cluster(final Collection<T> points)
  95.         throws MathIllegalArgumentException, MathIllegalStateException {

  96.         // at first, we have not found any clusters list yet
  97.         List<CentroidCluster<T>> best = null;
  98.         double bestVarianceSum = Double.POSITIVE_INFINITY;

  99.         // do several clustering trials
  100.         for (int i = 0; i < numTrials; ++i) {

  101.             // compute a clusters list
  102.             List<CentroidCluster<T>> clusters = clusterer.cluster(points);

  103.             // compute the variance of the current list
  104.             final double varianceSum = evaluator.score(clusters);

  105.             if (evaluator.isBetterScore(varianceSum, bestVarianceSum)) {
  106.                 // this one is the best we have found so far, remember it
  107.                 best            = clusters;
  108.                 bestVarianceSum = varianceSum;
  109.             }

  110.         }

  111.         // return the best clusters list found
  112.         return best;

  113.     }

  114. }