1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * https://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 /*
19 * This is not the original file distributed by the Apache Software Foundation
20 * It has been modified by the Hipparchus project
21 */
22
23 package org.hipparchus.clustering;
24
25 import java.util.Collection;
26 import java.util.List;
27
28 import org.hipparchus.clustering.evaluation.ClusterEvaluator;
29 import org.hipparchus.clustering.evaluation.SumOfClusterVariances;
30 import org.hipparchus.exception.MathIllegalArgumentException;
31 import org.hipparchus.exception.MathIllegalStateException;
32
33 /**
34 * A wrapper around a k-means++ clustering algorithm which performs multiple trials
35 * and returns the best solution.
36 * @param <T> type of the points to cluster
37 */
38 public class MultiKMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T> {
39
40 /** The underlying k-means clusterer. */
41 private final KMeansPlusPlusClusterer<T> clusterer;
42
43 /** The number of trial runs. */
44 private final int numTrials;
45
46 /** The cluster evaluator to use. */
47 private final ClusterEvaluator<T> evaluator;
48
49 /** Build a clusterer.
50 * @param clusterer the k-means clusterer to use
51 * @param numTrials number of trial runs
52 */
53 public MultiKMeansPlusPlusClusterer(final KMeansPlusPlusClusterer<T> clusterer,
54 final int numTrials) {
55 this(clusterer, numTrials, new SumOfClusterVariances<>(clusterer.getDistanceMeasure()));
56 }
57
58 /** Build a clusterer.
59 * @param clusterer the k-means clusterer to use
60 * @param numTrials number of trial runs
61 * @param evaluator the cluster evaluator to use
62 */
63 public MultiKMeansPlusPlusClusterer(final KMeansPlusPlusClusterer<T> clusterer,
64 final int numTrials,
65 final ClusterEvaluator<T> evaluator) {
66 super(clusterer.getDistanceMeasure());
67 this.clusterer = clusterer;
68 this.numTrials = numTrials;
69 this.evaluator = evaluator;
70 }
71
72 /**
73 * Returns the embedded k-means clusterer used by this instance.
74 * @return the embedded clusterer
75 */
76 public KMeansPlusPlusClusterer<T> getClusterer() {
77 return clusterer;
78 }
79
80 /**
81 * Returns the number of trials this instance will do.
82 * @return the number of trials
83 */
84 public int getNumTrials() {
85 return numTrials;
86 }
87
88 /**
89 * Returns the {@link ClusterEvaluator} used to determine the "best" clustering.
90 * @return the used {@link ClusterEvaluator}
91 */
92 public ClusterEvaluator<T> getClusterEvaluator() {
93 return evaluator;
94 }
95
96 /**
97 * Runs the K-means++ clustering algorithm.
98 *
99 * @param points the points to cluster
100 * @return a list of clusters containing the points
101 * @throws MathIllegalArgumentException if the data points are null or the number
102 * of clusters is larger than the number of data points
103 * @throws MathIllegalStateException if an empty cluster is encountered and the
104 * underlying {@link KMeansPlusPlusClusterer} has its
105 * {@link KMeansPlusPlusClusterer.EmptyClusterStrategy} is set to {@code ERROR}.
106 */
107 @Override
108 public List<CentroidCluster<T>> cluster(final Collection<T> points)
109 throws MathIllegalArgumentException, MathIllegalStateException {
110
111 // at first, we have not found any clusters list yet
112 List<CentroidCluster<T>> best = null;
113 double bestVarianceSum = Double.POSITIVE_INFINITY;
114
115 // do several clustering trials
116 for (int i = 0; i < numTrials; ++i) {
117
118 // compute a clusters list
119 List<CentroidCluster<T>> clusters = clusterer.cluster(points);
120
121 // compute the variance of the current list
122 final double varianceSum = evaluator.score(clusters);
123
124 if (evaluator.isBetterScore(varianceSum, bestVarianceSum)) {
125 // this one is the best we have found so far, remember it
126 best = clusters;
127 bestVarianceSum = varianceSum;
128 }
129
130 }
131
132 // return the best clusters list found
133 return best;
134
135 }
136
137 }