1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 package org.hipparchus.samples;
23
24 import java.awt.Color;
25 import java.awt.Dimension;
26 import java.awt.Graphics;
27 import java.awt.Graphics2D;
28 import java.awt.GridBagConstraints;
29 import java.awt.GridBagLayout;
30 import java.awt.Insets;
31 import java.awt.RenderingHints;
32 import java.awt.Shape;
33 import java.awt.geom.Ellipse2D;
34 import java.util.ArrayList;
35 import java.util.Arrays;
36 import java.util.Collections;
37 import java.util.List;
38
39 import javax.swing.JComponent;
40 import javax.swing.JLabel;
41
42 import org.hipparchus.clustering.CentroidCluster;
43 import org.hipparchus.clustering.Cluster;
44 import org.hipparchus.clustering.Clusterable;
45 import org.hipparchus.clustering.Clusterer;
46 import org.hipparchus.clustering.DBSCANClusterer;
47 import org.hipparchus.clustering.DoublePoint;
48 import org.hipparchus.clustering.FuzzyKMeansClusterer;
49 import org.hipparchus.clustering.KMeansPlusPlusClusterer;
50 import org.hipparchus.geometry.euclidean.twod.Vector2D;
51 import org.hipparchus.random.RandomAdaptor;
52 import org.hipparchus.random.RandomDataGenerator;
53 import org.hipparchus.random.RandomGenerator;
54 import org.hipparchus.random.SobolSequenceGenerator;
55 import org.hipparchus.random.Well19937c;
56 import org.hipparchus.samples.ExampleUtils.ExampleFrame;
57 import org.hipparchus.util.FastMath;
58 import org.hipparchus.util.Pair;
59 import org.hipparchus.util.SinCos;
60
61
62
63
64
65
66
67 public class ClusterAlgorithmComparison {
68
69
70
71
72
73
74
75
/**
 * Creates a new instance.
 * <p>
 * The class declares no instance fields and all of its helpers are static,
 * so the created object carries no state of its own.
 */
public ClusterAlgorithmComparison() {
    super(); // nothing else to initialise
}
79
80
81
82
83
84
85
86
87
88 public static List<Vector2D> makeCircles(int samples, boolean shuffle, double noise, double factor, final RandomGenerator random) {
89 if (factor < 0 || factor > 1) {
90 throw new IllegalArgumentException();
91 }
92
93 List<Vector2D> points = new ArrayList<Vector2D>();
94 double range = 2.0 * FastMath.PI;
95 double step = range / (samples / 2.0 + 1);
96 for (double angle = 0; angle < range; angle += step) {
97 Vector2D outerCircle = buildVector(angle);
98 Vector2D innerCircle = outerCircle.scalarMultiply(factor);
99
100 points.add(outerCircle.add(generateNoiseVector(random, noise)));
101 points.add(innerCircle.add(generateNoiseVector(random, noise)));
102 }
103
104 if (shuffle) {
105 Collections.shuffle(points, new RandomAdaptor(random));
106 }
107
108 return points;
109 }
110
111
112
113
114
115
116
117
118 public static List<Vector2D> makeMoons(int samples, boolean shuffle, double noise, RandomGenerator random) {
119
120 int nSamplesOut = samples / 2;
121 int nSamplesIn = samples - nSamplesOut;
122
123 List<Vector2D> points = new ArrayList<Vector2D>();
124 double range = FastMath.PI;
125 double step = range / (nSamplesOut / 2.0);
126 for (double angle = 0; angle < range; angle += step) {
127 Vector2D outerCircle = buildVector(angle);
128 points.add(outerCircle.add(generateNoiseVector(random, noise)));
129 }
130
131 step = range / (nSamplesIn / 2.0);
132 for (double angle = 0; angle < range; angle += step) {
133 final SinCos sc = FastMath.sinCos(angle);
134 Vector2D innerCircle = new Vector2D(1 - sc.cos(), 1 - sc.sin() - 0.5);
135 points.add(innerCircle.add(generateNoiseVector(random, noise)));
136 }
137
138 if (shuffle) {
139 Collections.shuffle(points, new RandomAdaptor(random));
140 }
141
142 return points;
143 }
144
145
146
147
148
149
150
151
152
153
154
155 public static List<Vector2D> makeBlobs(int samples, int centers, double clusterStd,
156 double min, double max, boolean shuffle, RandomGenerator random) {
157
158 final RandomDataGenerator randomDataGenerator = RandomDataGenerator.of(random);
159
160
161 double range = max - min;
162 Vector2D[] centerPoints = new Vector2D[centers];
163 for (int i = 0; i < centers; i++) {
164 double x = random.nextDouble() * range + min;
165 double y = random.nextDouble() * range + min;
166 centerPoints[i] = new Vector2D(x, y);
167 }
168
169 int[] nSamplesPerCenter = new int[centers];
170 int count = samples / centers;
171 Arrays.fill(nSamplesPerCenter, count);
172
173 for (int i = 0; i < samples % centers; i++) {
174 nSamplesPerCenter[i]++;
175 }
176
177 List<Vector2D> points = new ArrayList<Vector2D>();
178 for (int i = 0; i < centers; i++) {
179 for (int j = 0; j < nSamplesPerCenter[i]; j++) {
180 Vector2D point = new Vector2D(randomDataGenerator.nextNormal(0, clusterStd),
181 randomDataGenerator.nextNormal(0, clusterStd));
182 points.add(point.add(centerPoints[i]));
183 }
184 }
185
186 if (shuffle) {
187 Collections.shuffle(points, new RandomAdaptor(random));
188 }
189
190 return points;
191 }
192
193
194
195
196
197 public static List<Vector2D> makeSobol(int samples) {
198 SobolSequenceGenerator generator = new SobolSequenceGenerator(2);
199 generator.skipTo(999999);
200 List<Vector2D> points = new ArrayList<Vector2D>();
201 for (double i = 0; i < samples; i++) {
202 double[] vector = generator.nextVector();
203 vector[0] = vector[0] * 2 - 1;
204 vector[1] = vector[1] * 2 - 1;
205 Vector2D point = new Vector2D(vector);
206 points.add(point);
207 }
208
209 return points;
210 }
211
212
213
214
215
216
217 public static Vector2D generateNoiseVector(RandomGenerator randomGenerator, double noise) {
218 final RandomDataGenerator randomDataGenerator = RandomDataGenerator.of(randomGenerator);
219 return new Vector2D(randomDataGenerator.nextNormal(0, noise), randomDataGenerator.nextNormal(0, noise));
220 }
221
222
223
224
225
226
227
228
229
230 public static List<DoublePoint> normalize(final List<Vector2D> input, double minX, double maxX, double minY, double maxY) {
231 double rangeX = maxX - minX;
232 double rangeY = maxY - minY;
233 List<DoublePoint> points = new ArrayList<DoublePoint>();
234 for (Vector2D p : input) {
235 double[] arr = p.toArray();
236 arr[0] = (arr[0] - minX) / rangeX * 2 - 1;
237 arr[1] = (arr[1] - minY) / rangeY * 2 - 1;
238 points.add(new DoublePoint(arr));
239 }
240 return points;
241 }
242
243
244
245
246
247
248 private static Vector2D buildVector(final double alpha) {
249 final SinCos sc = FastMath.sinCos(alpha);
250 return new Vector2D(sc.cos(), sc.sin());
251 }
252
253
254 @SuppressWarnings("serial")
255 public static class Display extends ExampleFrame {
256
257
258 public Display() {
259 setTitle("Hipparchus: Cluster algorithm comparison");
260 setSize(800, 800);
261
262 setLayout(new GridBagLayout());
263
264 int nSamples = 1500;
265
266 RandomGenerator rng = new Well19937c(0);
267 List<List<DoublePoint>> datasets = new ArrayList<List<DoublePoint>>();
268
269 datasets.add(normalize(makeCircles(nSamples, true, 0.04, 0.5, rng), -1, 1, -1, 1));
270 datasets.add(normalize(makeMoons(nSamples, true, 0.04, rng), -1, 2, -1, 1));
271 datasets.add(normalize(makeBlobs(nSamples, 3, 1.0, -10, 10, true, rng), -12, 12, -12, 12));
272 datasets.add(normalize(makeSobol(nSamples), -1, 1, -1, 1));
273
274 List<Pair<String, Clusterer<DoublePoint>>> algorithms = new ArrayList<Pair<String, Clusterer<DoublePoint>>>();
275
276 algorithms.add(new Pair<String, Clusterer<DoublePoint>>("KMeans\n(k=2)", new KMeansPlusPlusClusterer<DoublePoint>(2)));
277 algorithms.add(new Pair<String, Clusterer<DoublePoint>>("KMeans\n(k=3)", new KMeansPlusPlusClusterer<DoublePoint>(3)));
278 algorithms.add(new Pair<String, Clusterer<DoublePoint>>("FuzzyKMeans\n(k=3, fuzzy=2)", new FuzzyKMeansClusterer<DoublePoint>(3, 2)));
279 algorithms.add(new Pair<String, Clusterer<DoublePoint>>("FuzzyKMeans\n(k=3, fuzzy=10)", new FuzzyKMeansClusterer<DoublePoint>(3, 10)));
280 algorithms.add(new Pair<String, Clusterer<DoublePoint>>("DBSCAN\n(eps=.1, min=3)", new DBSCANClusterer<DoublePoint>(0.1, 3)));
281
282 GridBagConstraints c = new GridBagConstraints();
283 c.fill = GridBagConstraints.VERTICAL;
284 c.gridx = 0;
285 c.gridy = 0;
286 c.insets = new Insets(2, 2, 2, 2);
287
288 for (Pair<String, Clusterer<DoublePoint>> pair : algorithms) {
289 JLabel text = new JLabel("<html><body>" + pair.getFirst().replace("\n", "<br>"));
290 add(text, c);
291 c.gridx++;
292 }
293 c.gridy++;
294
295 for (List<DoublePoint> dataset : datasets) {
296 c.gridx = 0;
297 for (Pair<String, Clusterer<DoublePoint>> pair : algorithms) {
298 long start = System.currentTimeMillis();
299 List<? extends Cluster<DoublePoint>> clusters = pair.getSecond().cluster(dataset);
300 long end = System.currentTimeMillis();
301 add(new ClusterPlot(clusters, end - start), c);
302 c.gridx++;
303 }
304 c.gridy++;
305 }
306 }
307
308 }
309
310
311 @SuppressWarnings("serial")
312 public static class ClusterPlot extends JComponent {
313
314
315 private static final double PAD = 10;
316
317
318 private List<? extends Cluster<DoublePoint>> clusters;
319
320
321 private long duration;
322
323
324
325
326
327 public ClusterPlot(final List<? extends Cluster<DoublePoint>> clusters, long duration) {
328 this.clusters = clusters;
329 this.duration = duration;
330 }
331
332 @Override
333 protected void paintComponent(Graphics g) {
334 super.paintComponent(g);
335 Graphics2D g2 = (Graphics2D)g;
336 g2.setRenderingHint(RenderingHints.KEY_ANTIALIASING,
337 RenderingHints.VALUE_ANTIALIAS_ON);
338
339 int w = getWidth();
340 int h = getHeight();
341
342 g2.clearRect(0, 0, w, h);
343
344 g2.setPaint(Color.black);
345 g2.drawRect(0, 0, w - 1, h - 1);
346
347 int index = 0;
348 Color[] colors = new Color[] { Color.red, Color.blue, Color.green.darker() };
349 for (Cluster<DoublePoint> cluster : clusters) {
350 g2.setPaint(colors[index++]);
351 for (DoublePoint point : cluster.getPoints()) {
352 Clusterable p = transform(point, w, h);
353 double[] arr = p.getPoint();
354 g2.fill(new Ellipse2D.Double(arr[0] - 1, arr[1] - 1, 3, 3));
355 }
356
357 if (cluster instanceof CentroidCluster) {
358 Clusterable p = transform(((CentroidCluster<?>) cluster).getCenter(), w, h);
359 double[] arr = p.getPoint();
360 Shape s = new Ellipse2D.Double(arr[0] - 4, arr[1] - 4, 8, 8);
361 g2.fill(s);
362 g2.setPaint(Color.black);
363 g2.draw(s);
364 }
365 }
366
367 g2.setPaint(Color.black);
368 g2.drawString(String.format("%.2f s", duration / 1e3), w - 40, h - 5);
369 }
370
371 @Override
372 public Dimension getPreferredSize() {
373 return new Dimension(150, 150);
374 }
375
376 private Clusterable transform(Clusterable point, int width, int height) {
377 double[] arr = point.getPoint();
378 return new DoublePoint(new double[] { PAD + (arr[0] + 1) / 2.0 * (width - 2 * PAD),
379 height - PAD - (arr[1] + 1) / 2.0 * (height - 2 * PAD) });
380 }
381 }
382
383
384
385
386 public static void main(String[] args) {
387 ExampleUtils.showExampleFrame(new Display());
388 }
389
390 }