FuzzyKMeansClusterer.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      https://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. /*
  18.  * This is not the original file distributed by the Apache Software Foundation
  19.  * It has been modified by the Hipparchus project
  20.  */
  21. package org.hipparchus.clustering;

  22. import java.util.ArrayList;
  23. import java.util.Collection;
  24. import java.util.Collections;
  25. import java.util.List;

  26. import org.hipparchus.clustering.distance.DistanceMeasure;
  27. import org.hipparchus.clustering.distance.EuclideanDistance;
  28. import org.hipparchus.exception.LocalizedCoreFormats;
  29. import org.hipparchus.exception.MathIllegalArgumentException;
  30. import org.hipparchus.exception.MathIllegalStateException;
  31. import org.hipparchus.linear.MatrixUtils;
  32. import org.hipparchus.linear.RealMatrix;
  33. import org.hipparchus.random.JDKRandomGenerator;
  34. import org.hipparchus.random.RandomGenerator;
  35. import org.hipparchus.util.FastMath;
  36. import org.hipparchus.util.MathArrays;
  37. import org.hipparchus.util.MathUtils;

  38. /**
  39.  * Fuzzy K-Means clustering algorithm.
  40.  * <p>
  41.  * The Fuzzy K-Means algorithm is a variation of the classical K-Means algorithm, with the
  42.  * major difference that a single data point is not uniquely assigned to a single cluster.
  43.  * Instead, each point i has a set of weights u<sub>ij</sub> which indicate the degree of membership
  44.  * to the cluster j.
  45.  * <p>The algorithm then tries to minimize the objective function:
  46.  * \[
 * J = \sum_{i=1}^{N}\sum_{k=1}^{C} u_{i,k}^m d_{i,k}^2
 * \]
 * with \(d_{i,k}\) being the distance between data point i and the cluster center k,
 * N the number of data points, C the number of clusters and m the fuzziness factor.
  50.  * </p>
  51.  * <p>The algorithm requires two parameters:</p>
  52.  * <ul>
  53.  *   <li>k: the number of clusters
  54.  *   <li>fuzziness: determines the level of cluster fuzziness, larger values lead to fuzzier clusters
  55.  * </ul>
  56.  * <p>Additional, optional parameters:</p>
  57.  * <ul>
  58.  *   <li>maxIterations: the maximum number of iterations
  59.  *   <li>epsilon: the convergence criteria, default is 1e-3
  60.  * </ul>
  61.  * <p>
  62.  * The fuzzy variant of the K-Means algorithm is more robust with regard to the selection
  63.  * of the initial cluster centers.
  64.  * </p>
  65.  *
  66.  * @param <T> type of the points to cluster
  67.  */
  68. public class FuzzyKMeansClusterer<T extends Clusterable> extends Clusterer<T> {

  69.     /** The default value for the convergence criteria. */
  70.     private static final double DEFAULT_EPSILON = 1e-3;

  71.     /** The number of clusters. */
  72.     private final int k;

  73.     /** The maximum number of iterations. */
  74.     private final int maxIterations;

  75.     /** The fuzziness factor. */
  76.     private final double fuzziness;

  77.     /** The convergence criteria. */
  78.     private final double epsilon;

  79.     /** Random generator for choosing initial centers. */
  80.     private final RandomGenerator random;

  81.     /** The membership matrix. */
  82.     private double[][] membershipMatrix;

  83.     /** The list of points used in the last call to {@link #cluster(Collection)}. */
  84.     private List<T> points;

  85.     /** The list of clusters resulting from the last call to {@link #cluster(Collection)}. */
  86.     private List<CentroidCluster<T>> clusters;

  87.     /**
  88.      * Creates a new instance of a FuzzyKMeansClusterer.
  89.      * <p>
  90.      * The euclidean distance will be used as default distance measure.
  91.      *
  92.      * @param k the number of clusters to split the data into
  93.      * @param fuzziness the fuzziness factor, must be &gt; 1.0
  94.      * @throws MathIllegalArgumentException if {@code fuzziness <= 1.0}
  95.      */
  96.     public FuzzyKMeansClusterer(final int k, final double fuzziness) throws MathIllegalArgumentException {
  97.         this(k, fuzziness, -1, new EuclideanDistance());
  98.     }

  99.     /**
  100.      * Creates a new instance of a FuzzyKMeansClusterer.
  101.      *
  102.      * @param k the number of clusters to split the data into
  103.      * @param fuzziness the fuzziness factor, must be &gt; 1.0
  104.      * @param maxIterations the maximum number of iterations to run the algorithm for.
  105.      *   If negative, no maximum will be used.
  106.      * @param measure the distance measure to use
  107.      * @throws MathIllegalArgumentException if {@code fuzziness <= 1.0}
  108.      */
  109.     public FuzzyKMeansClusterer(final int k, final double fuzziness,
  110.                                 final int maxIterations, final DistanceMeasure measure)
  111.             throws MathIllegalArgumentException {
  112.         this(k, fuzziness, maxIterations, measure, DEFAULT_EPSILON, new JDKRandomGenerator());
  113.     }

  114.     /**
  115.      * Creates a new instance of a FuzzyKMeansClusterer.
  116.      *
  117.      * @param k the number of clusters to split the data into
  118.      * @param fuzziness the fuzziness factor, must be &gt; 1.0
  119.      * @param maxIterations the maximum number of iterations to run the algorithm for.
  120.      *   If negative, no maximum will be used.
  121.      * @param measure the distance measure to use
  122.      * @param epsilon the convergence criteria (default is 1e-3)
  123.      * @param random random generator to use for choosing initial centers
  124.      * @throws MathIllegalArgumentException if {@code fuzziness <= 1.0}
  125.      */
  126.     public FuzzyKMeansClusterer(final int k, final double fuzziness,
  127.                                 final int maxIterations, final DistanceMeasure measure,
  128.                                 final double epsilon, final RandomGenerator random)
  129.             throws MathIllegalArgumentException {

  130.         super(measure);

  131.         if (fuzziness <= 1.0d) {
  132.             throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL_BOUND_EXCLUDED,
  133.                                                    fuzziness, 1.0);
  134.         }
  135.         this.k = k;
  136.         this.fuzziness = fuzziness;
  137.         this.maxIterations = maxIterations;
  138.         this.epsilon = epsilon;
  139.         this.random = random;

  140.         this.membershipMatrix = null;
  141.         this.points = null;
  142.         this.clusters = null;
  143.     }

  144.     /**
  145.      * Return the number of clusters this instance will use.
  146.      * @return the number of clusters
  147.      */
  148.     public int getK() {
  149.         return k;
  150.     }

  151.     /**
  152.      * Returns the fuzziness factor used by this instance.
  153.      * @return the fuzziness factor
  154.      */
  155.     public double getFuzziness() {
  156.         return fuzziness;
  157.     }

  158.     /**
  159.      * Returns the maximum number of iterations this instance will use.
  160.      * @return the maximum number of iterations, or -1 if no maximum is set
  161.      */
  162.     public int getMaxIterations() {
  163.         return maxIterations;
  164.     }

  165.     /**
  166.      * Returns the convergence criteria used by this instance.
  167.      * @return the convergence criteria
  168.      */
  169.     public double getEpsilon() {
  170.         return epsilon;
  171.     }

  172.     /**
  173.      * Returns the random generator this instance will use.
  174.      * @return the random generator
  175.      */
  176.     public RandomGenerator getRandomGenerator() {
  177.         return random;
  178.     }

  179.     /**
  180.      * Returns the {@code nxk} membership matrix, where {@code n} is the number
  181.      * of data points and {@code k} the number of clusters.
  182.      * <p>
  183.      * The element U<sub>i,j</sub> represents the membership value for data point {@code i}
  184.      * to cluster {@code j}.
  185.      *
  186.      * @return the membership matrix
  187.      * @throws MathIllegalStateException if {@link #cluster(Collection)} has not been called before
  188.      */
  189.     public RealMatrix getMembershipMatrix() {
  190.         if (membershipMatrix == null) {
  191.             throw new MathIllegalStateException(LocalizedCoreFormats.ILLEGAL_STATE);
  192.         }
  193.         return MatrixUtils.createRealMatrix(membershipMatrix);
  194.     }

  195.     /**
  196.      * Returns an unmodifiable list of the data points used in the last
  197.      * call to {@link #cluster(Collection)}.
  198.      * @return the list of data points, or {@code null} if {@link #cluster(Collection)} has
  199.      *   not been called before.
  200.      */
  201.     public List<T> getDataPoints() {
  202.         return points;
  203.     }

  204.     /**
  205.      * Returns the list of clusters resulting from the last call to {@link #cluster(Collection)}.
  206.      * @return the list of clusters, or {@code null} if {@link #cluster(Collection)} has
  207.      *   not been called before.
  208.      */
  209.     public List<CentroidCluster<T>> getClusters() {
  210.         return clusters;
  211.     }

  212.     /**
  213.      * Get the value of the objective function.
  214.      * @return the objective function evaluation as double value
  215.      * @throws MathIllegalStateException if {@link #cluster(Collection)} has not been called before
  216.      */
  217.     public double getObjectiveFunctionValue() {
  218.         if (points == null || clusters == null) {
  219.             throw new MathIllegalStateException(LocalizedCoreFormats.ILLEGAL_STATE);
  220.         }

  221.         int i = 0;
  222.         double objFunction = 0.0;
  223.         for (final T point : points) {
  224.             int j = 0;
  225.             for (final CentroidCluster<T> cluster : clusters) {
  226.                 final double dist = distance(point, cluster.getCenter());
  227.                 objFunction += (dist * dist) * FastMath.pow(membershipMatrix[i][j], fuzziness);
  228.                 j++;
  229.             }
  230.             i++;
  231.         }
  232.         return objFunction;
  233.     }

  234.     /**
  235.      * Performs Fuzzy K-Means cluster analysis.
  236.      *
  237.      * @param dataPoints the points to cluster
  238.      * @return the list of clusters
  239.      * @throws MathIllegalArgumentException if the data points are null or the number
  240.      *     of clusters is larger than the number of data points
  241.      */
  242.     @Override
  243.     public List<CentroidCluster<T>> cluster(final Collection<T> dataPoints)
  244.             throws MathIllegalArgumentException {

  245.         // sanity checks
  246.         MathUtils.checkNotNull(dataPoints);

  247.         final int size = dataPoints.size();

  248.         // number of clusters has to be smaller or equal the number of data points
  249.         if (size < k) {
  250.             throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL_BOUND_EXCLUDED,
  251.                                                    size, k);
  252.         }

  253.         // copy the input collection to an unmodifiable list with indexed access
  254.         points = Collections.unmodifiableList(new ArrayList<>(dataPoints));
  255.         clusters = new ArrayList<>();
  256.         membershipMatrix = new double[size][k];
  257.         final double[][] oldMatrix = new double[size][k];

  258.         // if no points are provided, return an empty list of clusters
  259.         if (size == 0) {
  260.             return clusters;
  261.         }

  262.         initializeMembershipMatrix();

  263.         // there is at least one point
  264.         final int pointDimension = points.get(0).getPoint().length;
  265.         for (int i = 0; i < k; i++) {
  266.             clusters.add(new CentroidCluster<>(new DoublePoint(new double[pointDimension])));
  267.         }

  268.         int iteration = 0;
  269.         final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
  270.         double difference;

  271.         do {
  272.             saveMembershipMatrix(oldMatrix);
  273.             updateClusterCenters();
  274.             updateMembershipMatrix();
  275.             difference = calculateMaxMembershipChange(oldMatrix);
  276.         } while (difference > epsilon && ++iteration < max);

  277.         return clusters;
  278.     }

  279.     /**
  280.      * Update the cluster centers.
  281.      */
  282.     private void updateClusterCenters() {
  283.         int j = 0;
  284.         final List<CentroidCluster<T>> newClusters = new ArrayList<>(k);
  285.         for (final CentroidCluster<T> cluster : clusters) {
  286.             final Clusterable center = cluster.getCenter();
  287.             int i = 0;
  288.             double[] arr = new double[center.getPoint().length];
  289.             double sum = 0.0;
  290.             for (final T point : points) {
  291.                 final double u = FastMath.pow(membershipMatrix[i][j], fuzziness);
  292.                 final double[] pointArr = point.getPoint();
  293.                 for (int idx = 0; idx < arr.length; idx++) {
  294.                     arr[idx] += u * pointArr[idx];
  295.                 }
  296.                 sum += u;
  297.                 i++;
  298.             }
  299.             MathArrays.scaleInPlace(1.0 / sum, arr);
  300.             newClusters.add(new CentroidCluster<>(new DoublePoint(arr)));
  301.             j++;
  302.         }
  303.         clusters.clear();
  304.         clusters = newClusters;
  305.     }

  306.     /**
  307.      * Updates the membership matrix and assigns the points to the cluster with
  308.      * the highest membership.
  309.      */
  310.     private void updateMembershipMatrix() {
  311.         for (int i = 0; i < points.size(); i++) {
  312.             final T point = points.get(i);
  313.             double maxMembership = Double.MIN_VALUE;
  314.             int newCluster = -1;
  315.             for (int j = 0; j < clusters.size(); j++) {
  316.                 double sum = 0.0;
  317.                 final double distA = FastMath.abs(distance(point, clusters.get(j).getCenter()));

  318.                 if (distA != 0.0) {
  319.                     for (final CentroidCluster<T> c : clusters) {
  320.                         final double distB = FastMath.abs(distance(point, c.getCenter()));
  321.                         if (distB == 0.0) {
  322.                             sum = Double.POSITIVE_INFINITY;
  323.                             break;
  324.                         }
  325.                         sum += FastMath.pow(distA / distB, 2.0 / (fuzziness - 1.0));
  326.                     }
  327.                 }

  328.                 double membership;
  329.                 if (sum == 0.0) {
  330.                     membership = 1.0;
  331.                 } else if (sum == Double.POSITIVE_INFINITY) {
  332.                     membership = 0.0;
  333.                 } else {
  334.                     membership = 1.0 / sum;
  335.                 }
  336.                 membershipMatrix[i][j] = membership;

  337.                 if (membershipMatrix[i][j] > maxMembership) {
  338.                     maxMembership = membershipMatrix[i][j];
  339.                     newCluster = j;
  340.                 }
  341.             }
  342.             clusters.get(newCluster).addPoint(point);
  343.         }
  344.     }

  345.     /**
  346.      * Initialize the membership matrix with random values.
  347.      */
  348.     private void initializeMembershipMatrix() {
  349.         for (int i = 0; i < points.size(); i++) {
  350.             for (int j = 0; j < k; j++) {
  351.                 membershipMatrix[i][j] = random.nextDouble();
  352.             }
  353.             membershipMatrix[i] = MathArrays.normalizeArray(membershipMatrix[i], 1.0);
  354.         }
  355.     }

  356.     /**
  357.      * Calculate the maximum element-by-element change of the membership matrix
  358.      * for the current iteration.
  359.      *
  360.      * @param matrix the membership matrix of the previous iteration
  361.      * @return the maximum membership matrix change
  362.      */
  363.     private double calculateMaxMembershipChange(final double[][] matrix) {
  364.         double maxMembership = 0.0;
  365.         for (int i = 0; i < points.size(); i++) {
  366.             for (int j = 0; j < clusters.size(); j++) {
  367.                 double v = FastMath.abs(membershipMatrix[i][j] - matrix[i][j]);
  368.                 maxMembership = FastMath.max(v, maxMembership);
  369.             }
  370.         }
  371.         return maxMembership;
  372.     }

  373.     /**
  374.      * Copy the membership matrix into the provided matrix.
  375.      *
  376.      * @param matrix the place to store the membership matrix
  377.      */
  378.     private void saveMembershipMatrix(final double[][] matrix) {
  379.         for (int i = 0; i < points.size(); i++) {
  380.             System.arraycopy(membershipMatrix[i], 0, matrix[i], 0, clusters.size());
  381.         }
  382.     }

  383. }