KendallsCorrelation.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      https://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. /*
  18.  * This is not the original file distributed by the Apache Software Foundation
  19.  * It has been modified by the Hipparchus project
  20.  */
  21. package org.hipparchus.stat.correlation;

  22. import java.util.Arrays;

  23. import org.hipparchus.exception.MathIllegalArgumentException;
  24. import org.hipparchus.linear.BlockRealMatrix;
  25. import org.hipparchus.linear.MatrixUtils;
  26. import org.hipparchus.linear.RealMatrix;
  27. import org.hipparchus.util.FastMath;
  28. import org.hipparchus.util.MathArrays;

  29. /**
  30.  * Implementation of Kendall's Tau-b rank correlation.
  31.  * <p>
  32.  * A pair of observations (x<sub>1</sub>, y<sub>1</sub>) and
  33.  * (x<sub>2</sub>, y<sub>2</sub>) are considered <i>concordant</i> if
  34.  * x<sub>1</sub> &lt; x<sub>2</sub> and y<sub>1</sub> &lt; y<sub>2</sub>
  35.  * or x<sub>2</sub> &lt; x<sub>1</sub> and y<sub>2</sub> &lt; y<sub>1</sub>.
  36.  * The pair is <i>discordant</i> if x<sub>1</sub> &lt; x<sub>2</sub> and
  37.  * y<sub>2</sub> &lt; y<sub>1</sub> or x<sub>2</sub> &lt; x<sub>1</sub> and
  38.  * y<sub>1</sub> &lt; y<sub>2</sub>.  If either x<sub>1</sub> = x<sub>2</sub>
  39.  * or y<sub>1</sub> = y<sub>2</sub>, the pair is neither concordant nor
  40.  * discordant.
  41.  * <p>
  42.  * Kendall's Tau-b is defined as:
  43.  * \[
  44.  * \tau_b = \frac{n_c - n_d}{\sqrt{(n_0 - n_1) (n_0 - n_2)}}
  45.  * \]
  46.  * <p>
  47.  * where:
  48.  * <ul>
  49.  *     <li>n<sub>0</sub> = n * (n - 1) / 2</li>
  50.  *     <li>n<sub>c</sub> = Number of concordant pairs</li>
  51.  *     <li>n<sub>d</sub> = Number of discordant pairs</li>
  52.  *     <li>n<sub>1</sub> = sum of t<sub>i</sub> * (t<sub>i</sub> - 1) / 2 for all i</li>
  53.  *     <li>n<sub>2</sub> = sum of u<sub>j</sub> * (u<sub>j</sub> - 1) / 2 for all j</li>
  54.  *     <li>t<sub>i</sub> = Number of tied values in the i<sup>th</sup> group of ties in x</li>
  55.  *     <li>u<sub>j</sub> = Number of tied values in the j<sup>th</sup> group of ties in y</li>
  56.  * </ul>
  57.  * <p>
  58.  * This implementation uses the O(n log n) algorithm described in
  59.  * William R. Knight's 1966 paper "A Computer Method for Calculating
  60.  * Kendall's Tau with Ungrouped Data" in the Journal of the American
  61.  * Statistical Association.
  62.  *
  63.  * @see <a href="http://en.wikipedia.org/wiki/Kendall_tau_rank_correlation_coefficient">
  64.  * Kendall tau rank correlation coefficient (Wikipedia)</a>
  65.  * @see <a href="http://www.jstor.org/stable/2282833">A Computer
  66.  * Method for Calculating Kendall's Tau with Ungrouped Data</a>
  67.  */
  68. public class KendallsCorrelation {

  69.     /** correlation matrix */
  70.     private final RealMatrix correlationMatrix;

  71.     /**
  72.      * Create a KendallsCorrelation instance without data.
  73.      */
  74.     public KendallsCorrelation() {
  75.         correlationMatrix = null;
  76.     }

  77.     /**
  78.      * Create a KendallsCorrelation from a rectangular array
  79.      * whose columns represent values of variables to be correlated.
  80.      *
  81.      * @param data rectangular array with columns representing variables
  82.      * @throws IllegalArgumentException if the input data array is not
  83.      * rectangular with at least two rows and two columns.
  84.      */
  85.     public KendallsCorrelation(double[][] data) {
  86.         this(MatrixUtils.createRealMatrix(data));
  87.     }

  88.     /**
  89.      * Create a KendallsCorrelation from a RealMatrix whose columns
  90.      * represent variables to be correlated.
  91.      *
  92.      * @param matrix matrix with columns representing variables to correlate
  93.      */
  94.     public KendallsCorrelation(RealMatrix matrix) {
  95.         correlationMatrix = computeCorrelationMatrix(matrix);
  96.     }

  97.     /**
  98.      * Returns the correlation matrix.
  99.      *
  100.      * @return correlation matrix
  101.      */
  102.     public RealMatrix getCorrelationMatrix() {
  103.         return correlationMatrix;
  104.     }

  105.     /**
  106.      * Computes the Kendall's Tau rank correlation matrix for the columns of
  107.      * the input matrix.
  108.      *
  109.      * @param matrix matrix with columns representing variables to correlate
  110.      * @return correlation matrix
  111.      */
  112.     public RealMatrix computeCorrelationMatrix(final RealMatrix matrix) {
  113.         int nVars = matrix.getColumnDimension();
  114.         RealMatrix outMatrix = new BlockRealMatrix(nVars, nVars);
  115.         for (int i = 0; i < nVars; i++) {
  116.             for (int j = 0; j < i; j++) {
  117.                 double corr = correlation(matrix.getColumn(i), matrix.getColumn(j));
  118.                 outMatrix.setEntry(i, j, corr);
  119.                 outMatrix.setEntry(j, i, corr);
  120.             }
  121.             outMatrix.setEntry(i, i, 1d);
  122.         }
  123.         return outMatrix;
  124.     }

  125.     /**
  126.      * Computes the Kendall's Tau rank correlation matrix for the columns of
  127.      * the input rectangular array.  The columns of the array represent values
  128.      * of variables to be correlated.
  129.      *
  130.      * @param matrix matrix with columns representing variables to correlate
  131.      * @return correlation matrix
  132.      */
  133.     public RealMatrix computeCorrelationMatrix(final double[][] matrix) {
  134.        return computeCorrelationMatrix(new BlockRealMatrix(matrix));
  135.     }

  136.     /**
  137.      * Computes the Kendall's Tau rank correlation coefficient between the two arrays.
  138.      *
  139.      * @param xArray first data array
  140.      * @param yArray second data array
  141.      * @return Returns Kendall's Tau rank correlation coefficient for the two arrays
  142.      * @throws MathIllegalArgumentException if the arrays lengths do not match
  143.      */
  144.     public double correlation(final double[] xArray, final double[] yArray)
  145.             throws MathIllegalArgumentException {

  146.         MathArrays.checkEqualLength(xArray, yArray);

  147.         final int n = xArray.length;
  148.         final long numPairs = sum(n - 1);

  149.         DoublePair[] pairs = new DoublePair[n];
  150.         for (int i = 0; i < n; i++) {
  151.             pairs[i] = new DoublePair(xArray[i], yArray[i]);
  152.         }

  153.         Arrays.sort(pairs, (p1, p2) -> {
  154.             int compareKey = Double.compare(p1.getFirst(), p2.getFirst());
  155.             return compareKey != 0 ? compareKey : Double.compare(p1.getSecond(), p2.getSecond());
  156.         });

  157.         long tiedXPairs = 0;
  158.         long tiedXYPairs = 0;
  159.         long consecutiveXTies = 1;
  160.         long consecutiveXYTies = 1;
  161.         DoublePair prev = pairs[0];
  162.         for (int i = 1; i < n; i++) {
  163.             final DoublePair curr = pairs[i];
  164.             if (Double.compare(curr.getFirst(), prev.getFirst()) == 0) {
  165.                 consecutiveXTies++;
  166.                 if (Double.compare(curr.getSecond(), prev.getSecond()) == 0) {
  167.                     consecutiveXYTies++;
  168.                 } else {
  169.                     tiedXYPairs += sum(consecutiveXYTies - 1);
  170.                     consecutiveXYTies = 1;
  171.                 }
  172.             } else {
  173.                 tiedXPairs += sum(consecutiveXTies - 1);
  174.                 consecutiveXTies = 1;
  175.                 tiedXYPairs += sum(consecutiveXYTies - 1);
  176.                 consecutiveXYTies = 1;
  177.             }
  178.             prev = curr;
  179.         }
  180.         tiedXPairs += sum(consecutiveXTies - 1);
  181.         tiedXYPairs += sum(consecutiveXYTies - 1);

  182.         long swaps = 0;
  183.         DoublePair[] pairsDestination = new DoublePair[n];
  184.         for (int segmentSize = 1; segmentSize < n; segmentSize <<= 1) {
  185.             for (int offset = 0; offset < n; offset += 2 * segmentSize) {
  186.                 int i = offset;
  187.                 final int iEnd = FastMath.min(i + segmentSize, n);
  188.                 int j = iEnd;
  189.                 final int jEnd = FastMath.min(j + segmentSize, n);

  190.                 int copyLocation = offset;
  191.                 while (i < iEnd || j < jEnd) {
  192.                     if (i < iEnd) {
  193.                         if (j < jEnd) {
  194.                             if (Double.compare(pairs[i].getSecond(), pairs[j].getSecond()) <= 0) {
  195.                                 pairsDestination[copyLocation] = pairs[i];
  196.                                 i++;
  197.                             } else {
  198.                                 pairsDestination[copyLocation] = pairs[j];
  199.                                 j++;
  200.                                 swaps += iEnd - i;
  201.                             }
  202.                         } else {
  203.                             pairsDestination[copyLocation] = pairs[i];
  204.                             i++;
  205.                         }
  206.                     } else {
  207.                         pairsDestination[copyLocation] = pairs[j];
  208.                         j++;
  209.                     }
  210.                     copyLocation++;
  211.                 }
  212.             }
  213.             final DoublePair[] pairsTemp = pairs;
  214.             pairs = pairsDestination;
  215.             pairsDestination = pairsTemp;
  216.         }

  217.         long tiedYPairs = 0;
  218.         long consecutiveYTies = 1;
  219.         prev = pairs[0];
  220.         for (int i = 1; i < n; i++) {
  221.             final DoublePair curr = pairs[i];
  222.             if (Double.compare(curr.getSecond(), prev.getSecond()) == 0) {
  223.                 consecutiveYTies++;
  224.             } else {
  225.                 tiedYPairs += sum(consecutiveYTies - 1);
  226.                 consecutiveYTies = 1;
  227.             }
  228.             prev = curr;
  229.         }
  230.         tiedYPairs += sum(consecutiveYTies - 1);

  231.         final long concordantMinusDiscordant = numPairs - tiedXPairs - tiedYPairs + tiedXYPairs - 2 * swaps;
  232.         final double nonTiedPairsMultiplied = (numPairs - tiedXPairs) * (double) (numPairs - tiedYPairs);
  233.         return concordantMinusDiscordant / FastMath.sqrt(nonTiedPairsMultiplied);
  234.     }

  235.     /**
  236.      * Returns the sum of the number from 1 .. n according to Gauss' summation formula:
  237.      * \[ \sum\limits_{k=1}^n k = \frac{n(n + 1)}{2} \]
  238.      *
  239.      * @param n the summation end
  240.      * @return the sum of the number from 1 to n
  241.      */
  242.     private static long sum(long n) {
  243.         return n * (n + 1) / 2l;
  244.     }

  245.     /**
  246.      * Helper data structure holding a (double, double) pair.
  247.      */
  248.     private static class DoublePair {
  249.         /** The first value */
  250.         private final double first;
  251.         /** The second value */
  252.         private final double second;

  253.         /**
  254.          * @param first first value.
  255.          * @param second second value.
  256.          */
  257.         DoublePair(double first, double second) {
  258.             this.first = first;
  259.             this.second = second;
  260.         }

  261.         /** @return the first value. */
  262.         public double getFirst() {
  263.             return first;
  264.         }

  265.         /** @return the second value. */
  266.         public double getSecond() {
  267.             return second;
  268.         }

  269.     }

  270. }