1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 package org.hipparchus.clustering;
23
24 import java.util.ArrayList;
25 import java.util.Collection;
26 import java.util.Collections;
27 import java.util.List;
28
29 import org.hipparchus.clustering.distance.DistanceMeasure;
30 import org.hipparchus.clustering.distance.EuclideanDistance;
31 import org.hipparchus.exception.LocalizedCoreFormats;
32 import org.hipparchus.exception.MathIllegalArgumentException;
33 import org.hipparchus.exception.MathIllegalStateException;
34 import org.hipparchus.linear.MatrixUtils;
35 import org.hipparchus.linear.RealMatrix;
36 import org.hipparchus.random.JDKRandomGenerator;
37 import org.hipparchus.random.RandomGenerator;
38 import org.hipparchus.util.FastMath;
39 import org.hipparchus.util.MathArrays;
40 import org.hipparchus.util.MathUtils;
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72 public class FuzzyKMeansClusterer<T extends Clusterable> extends Clusterer<T> {
73
74
75 private static final double DEFAULT_EPSILON = 1e-3;
76
77
78 private final int k;
79
80
81 private final int maxIterations;
82
83
84 private final double fuzziness;
85
86
87 private final double epsilon;
88
89
90 private final RandomGenerator random;
91
92
93 private double[][] membershipMatrix;
94
95
96 private List<T> points;
97
98
99 private List<CentroidCluster<T>> clusters;
100
101
102
103
104
105
106
107
108
109
110 public FuzzyKMeansClusterer(final int k, final double fuzziness) throws MathIllegalArgumentException {
111 this(k, fuzziness, -1, new EuclideanDistance());
112 }
113
114
115
116
117
118
119
120
121
122
123
124 public FuzzyKMeansClusterer(final int k, final double fuzziness,
125 final int maxIterations, final DistanceMeasure measure)
126 throws MathIllegalArgumentException {
127 this(k, fuzziness, maxIterations, measure, DEFAULT_EPSILON, new JDKRandomGenerator());
128 }
129
130
131
132
133
134
135
136
137
138
139
140
141
142 public FuzzyKMeansClusterer(final int k, final double fuzziness,
143 final int maxIterations, final DistanceMeasure measure,
144 final double epsilon, final RandomGenerator random)
145 throws MathIllegalArgumentException {
146
147 super(measure);
148
149 if (fuzziness <= 1.0d) {
150 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL_BOUND_EXCLUDED,
151 fuzziness, 1.0);
152 }
153 this.k = k;
154 this.fuzziness = fuzziness;
155 this.maxIterations = maxIterations;
156 this.epsilon = epsilon;
157 this.random = random;
158
159 this.membershipMatrix = null;
160 this.points = null;
161 this.clusters = null;
162 }
163
164
165
166
167
168 public int getK() {
169 return k;
170 }
171
172
173
174
175
176 public double getFuzziness() {
177 return fuzziness;
178 }
179
180
181
182
183
184 public int getMaxIterations() {
185 return maxIterations;
186 }
187
188
189
190
191
192 public double getEpsilon() {
193 return epsilon;
194 }
195
196
197
198
199
200 public RandomGenerator getRandomGenerator() {
201 return random;
202 }
203
204
205
206
207
208
209
210
211
212
213
214 public RealMatrix getMembershipMatrix() {
215 if (membershipMatrix == null) {
216 throw new MathIllegalStateException(LocalizedCoreFormats.ILLEGAL_STATE);
217 }
218 return MatrixUtils.createRealMatrix(membershipMatrix);
219 }
220
221
222
223
224
225
226
227 public List<T> getDataPoints() {
228 return points;
229 }
230
231
232
233
234
235
236 public List<CentroidCluster<T>> getClusters() {
237 return clusters;
238 }
239
240
241
242
243
244
245 public double getObjectiveFunctionValue() {
246 if (points == null || clusters == null) {
247 throw new MathIllegalStateException(LocalizedCoreFormats.ILLEGAL_STATE);
248 }
249
250 int i = 0;
251 double objFunction = 0.0;
252 for (final T point : points) {
253 int j = 0;
254 for (final CentroidCluster<T> cluster : clusters) {
255 final double dist = distance(point, cluster.getCenter());
256 objFunction += (dist * dist) * FastMath.pow(membershipMatrix[i][j], fuzziness);
257 j++;
258 }
259 i++;
260 }
261 return objFunction;
262 }
263
264
265
266
267
268
269
270
271
272 @Override
273 public List<CentroidCluster<T>> cluster(final Collection<T> dataPoints)
274 throws MathIllegalArgumentException {
275
276
277 MathUtils.checkNotNull(dataPoints);
278
279 final int size = dataPoints.size();
280
281
282 if (size < k) {
283 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL_BOUND_EXCLUDED,
284 size, k);
285 }
286
287
288 points = Collections.unmodifiableList(new ArrayList<>(dataPoints));
289 clusters = new ArrayList<>();
290 membershipMatrix = new double[size][k];
291 final double[][] oldMatrix = new double[size][k];
292
293
294 if (size == 0) {
295 return clusters;
296 }
297
298 initializeMembershipMatrix();
299
300
301 final int pointDimension = points.get(0).getPoint().length;
302 for (int i = 0; i < k; i++) {
303 clusters.add(new CentroidCluster<T>(new DoublePoint(new double[pointDimension])));
304 }
305
306 int iteration = 0;
307 final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
308 double difference;
309
310 do {
311 saveMembershipMatrix(oldMatrix);
312 updateClusterCenters();
313 updateMembershipMatrix();
314 difference = calculateMaxMembershipChange(oldMatrix);
315 } while (difference > epsilon && ++iteration < max);
316
317 return clusters;
318 }
319
320
321
322
323 private void updateClusterCenters() {
324 int j = 0;
325 final List<CentroidCluster<T>> newClusters = new ArrayList<>(k);
326 for (final CentroidCluster<T> cluster : clusters) {
327 final Clusterable center = cluster.getCenter();
328 int i = 0;
329 double[] arr = new double[center.getPoint().length];
330 double sum = 0.0;
331 for (final T point : points) {
332 final double u = FastMath.pow(membershipMatrix[i][j], fuzziness);
333 final double[] pointArr = point.getPoint();
334 for (int idx = 0; idx < arr.length; idx++) {
335 arr[idx] += u * pointArr[idx];
336 }
337 sum += u;
338 i++;
339 }
340 MathArrays.scaleInPlace(1.0 / sum, arr);
341 newClusters.add(new CentroidCluster<T>(new DoublePoint(arr)));
342 j++;
343 }
344 clusters.clear();
345 clusters = newClusters;
346 }
347
348
349
350
351
352 private void updateMembershipMatrix() {
353 for (int i = 0; i < points.size(); i++) {
354 final T point = points.get(i);
355 double maxMembership = Double.MIN_VALUE;
356 int newCluster = -1;
357 for (int j = 0; j < clusters.size(); j++) {
358 double sum = 0.0;
359 final double distA = FastMath.abs(distance(point, clusters.get(j).getCenter()));
360
361 if (distA != 0.0) {
362 for (final CentroidCluster<T> c : clusters) {
363 final double distB = FastMath.abs(distance(point, c.getCenter()));
364 if (distB == 0.0) {
365 sum = Double.POSITIVE_INFINITY;
366 break;
367 }
368 sum += FastMath.pow(distA / distB, 2.0 / (fuzziness - 1.0));
369 }
370 }
371
372 double membership;
373 if (sum == 0.0) {
374 membership = 1.0;
375 } else if (sum == Double.POSITIVE_INFINITY) {
376 membership = 0.0;
377 } else {
378 membership = 1.0 / sum;
379 }
380 membershipMatrix[i][j] = membership;
381
382 if (membershipMatrix[i][j] > maxMembership) {
383 maxMembership = membershipMatrix[i][j];
384 newCluster = j;
385 }
386 }
387 clusters.get(newCluster).addPoint(point);
388 }
389 }
390
391
392
393
394 private void initializeMembershipMatrix() {
395 for (int i = 0; i < points.size(); i++) {
396 for (int j = 0; j < k; j++) {
397 membershipMatrix[i][j] = random.nextDouble();
398 }
399 membershipMatrix[i] = MathArrays.normalizeArray(membershipMatrix[i], 1.0);
400 }
401 }
402
403
404
405
406
407
408
409
410 private double calculateMaxMembershipChange(final double[][] matrix) {
411 double maxMembership = 0.0;
412 for (int i = 0; i < points.size(); i++) {
413 for (int j = 0; j < clusters.size(); j++) {
414 double v = FastMath.abs(membershipMatrix[i][j] - matrix[i][j]);
415 maxMembership = FastMath.max(v, maxMembership);
416 }
417 }
418 return maxMembership;
419 }
420
421
422
423
424
425
426 private void saveMembershipMatrix(final double[][] matrix) {
427 for (int i = 0; i < points.size(); i++) {
428 System.arraycopy(membershipMatrix[i], 0, matrix[i], 0, clusters.size());
429 }
430 }
431
432 }