1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * https://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 /*
19 * This is not the original file distributed by the Apache Software Foundation
20 * It has been modified by the Hipparchus project
21 */
22 package org.hipparchus.stat.correlation;
23
24 import java.util.Arrays;
25
26 import org.hipparchus.exception.MathIllegalArgumentException;
27 import org.hipparchus.linear.BlockRealMatrix;
28 import org.hipparchus.linear.MatrixUtils;
29 import org.hipparchus.linear.RealMatrix;
30 import org.hipparchus.util.FastMath;
31 import org.hipparchus.util.MathArrays;
32
33 /**
34 * Implementation of Kendall's Tau-b rank correlation.
35 * <p>
36 * A pair of observations (x<sub>1</sub>, y<sub>1</sub>) and
37 * (x<sub>2</sub>, y<sub>2</sub>) are considered <i>concordant</i> if
 * x<sub>1</sub> &lt; x<sub>2</sub> and y<sub>1</sub> &lt; y<sub>2</sub>
 * or x<sub>2</sub> &lt; x<sub>1</sub> and y<sub>2</sub> &lt; y<sub>1</sub>.
 * The pair is <i>discordant</i> if x<sub>1</sub> &lt; x<sub>2</sub> and
 * y<sub>2</sub> &lt; y<sub>1</sub> or x<sub>2</sub> &lt; x<sub>1</sub> and
 * y<sub>1</sub> &lt; y<sub>2</sub>. If either x<sub>1</sub> = x<sub>2</sub>
43 * or y<sub>1</sub> = y<sub>2</sub>, the pair is neither concordant nor
44 * discordant.
45 * <p>
46 * Kendall's Tau-b is defined as:
47 * \[
48 * \tau_b = \frac{n_c - n_d}{\sqrt{(n_0 - n_1) (n_0 - n_2)}}
49 * \]
50 * <p>
51 * where:
52 * <ul>
53 * <li>n<sub>0</sub> = n * (n - 1) / 2</li>
54 * <li>n<sub>c</sub> = Number of concordant pairs</li>
55 * <li>n<sub>d</sub> = Number of discordant pairs</li>
56 * <li>n<sub>1</sub> = sum of t<sub>i</sub> * (t<sub>i</sub> - 1) / 2 for all i</li>
57 * <li>n<sub>2</sub> = sum of u<sub>j</sub> * (u<sub>j</sub> - 1) / 2 for all j</li>
58 * <li>t<sub>i</sub> = Number of tied values in the i<sup>th</sup> group of ties in x</li>
59 * <li>u<sub>j</sub> = Number of tied values in the j<sup>th</sup> group of ties in y</li>
60 * </ul>
61 * <p>
62 * This implementation uses the O(n log n) algorithm described in
63 * William R. Knight's 1966 paper "A Computer Method for Calculating
64 * Kendall's Tau with Ungrouped Data" in the Journal of the American
65 * Statistical Association.
66 *
67 * @see <a href="http://en.wikipedia.org/wiki/Kendall_tau_rank_correlation_coefficient">
68 * Kendall tau rank correlation coefficient (Wikipedia)</a>
69 * @see <a href="http://www.jstor.org/stable/2282833">A Computer
70 * Method for Calculating Kendall's Tau with Ungrouped Data</a>
71 */
72 public class KendallsCorrelation {
73
74 /** correlation matrix */
75 private final RealMatrix correlationMatrix;
76
77 /**
78 * Create a KendallsCorrelation instance without data.
79 */
80 public KendallsCorrelation() {
81 correlationMatrix = null;
82 }
83
84 /**
85 * Create a KendallsCorrelation from a rectangular array
86 * whose columns represent values of variables to be correlated.
87 *
88 * @param data rectangular array with columns representing variables
89 * @throws IllegalArgumentException if the input data array is not
90 * rectangular with at least two rows and two columns.
91 */
92 public KendallsCorrelation(double[][] data) {
93 this(MatrixUtils.createRealMatrix(data));
94 }
95
96 /**
97 * Create a KendallsCorrelation from a RealMatrix whose columns
98 * represent variables to be correlated.
99 *
100 * @param matrix matrix with columns representing variables to correlate
101 */
102 public KendallsCorrelation(RealMatrix matrix) {
103 correlationMatrix = computeCorrelationMatrix(matrix);
104 }
105
106 /**
107 * Returns the correlation matrix.
108 *
109 * @return correlation matrix
110 */
111 public RealMatrix getCorrelationMatrix() {
112 return correlationMatrix;
113 }
114
115 /**
116 * Computes the Kendall's Tau rank correlation matrix for the columns of
117 * the input matrix.
118 *
119 * @param matrix matrix with columns representing variables to correlate
120 * @return correlation matrix
121 */
122 public RealMatrix computeCorrelationMatrix(final RealMatrix matrix) {
123 int nVars = matrix.getColumnDimension();
124 RealMatrix outMatrix = new BlockRealMatrix(nVars, nVars);
125 for (int i = 0; i < nVars; i++) {
126 for (int j = 0; j < i; j++) {
127 double corr = correlation(matrix.getColumn(i), matrix.getColumn(j));
128 outMatrix.setEntry(i, j, corr);
129 outMatrix.setEntry(j, i, corr);
130 }
131 outMatrix.setEntry(i, i, 1d);
132 }
133 return outMatrix;
134 }
135
136 /**
137 * Computes the Kendall's Tau rank correlation matrix for the columns of
138 * the input rectangular array. The columns of the array represent values
139 * of variables to be correlated.
140 *
141 * @param matrix matrix with columns representing variables to correlate
142 * @return correlation matrix
143 */
144 public RealMatrix computeCorrelationMatrix(final double[][] matrix) {
145 return computeCorrelationMatrix(new BlockRealMatrix(matrix));
146 }
147
148 /**
149 * Computes the Kendall's Tau rank correlation coefficient between the two arrays.
150 *
151 * @param xArray first data array
152 * @param yArray second data array
153 * @return Returns Kendall's Tau rank correlation coefficient for the two arrays
154 * @throws MathIllegalArgumentException if the arrays lengths do not match
155 */
156 public double correlation(final double[] xArray, final double[] yArray)
157 throws MathIllegalArgumentException {
158
159 MathArrays.checkEqualLength(xArray, yArray);
160
161 final int n = xArray.length;
162 final long numPairs = sum(n - 1);
163
164 DoublePair[] pairs = new DoublePair[n];
165 for (int i = 0; i < n; i++) {
166 pairs[i] = new DoublePair(xArray[i], yArray[i]);
167 }
168
169 Arrays.sort(pairs, (p1, p2) -> {
170 int compareKey = Double.compare(p1.getFirst(), p2.getFirst());
171 return compareKey != 0 ? compareKey : Double.compare(p1.getSecond(), p2.getSecond());
172 });
173
174 long tiedXPairs = 0;
175 long tiedXYPairs = 0;
176 long consecutiveXTies = 1;
177 long consecutiveXYTies = 1;
178 DoublePair prev = pairs[0];
179 for (int i = 1; i < n; i++) {
180 final DoublePair curr = pairs[i];
181 if (Double.compare(curr.getFirst(), prev.getFirst()) == 0) {
182 consecutiveXTies++;
183 if (Double.compare(curr.getSecond(), prev.getSecond()) == 0) {
184 consecutiveXYTies++;
185 } else {
186 tiedXYPairs += sum(consecutiveXYTies - 1);
187 consecutiveXYTies = 1;
188 }
189 } else {
190 tiedXPairs += sum(consecutiveXTies - 1);
191 consecutiveXTies = 1;
192 tiedXYPairs += sum(consecutiveXYTies - 1);
193 consecutiveXYTies = 1;
194 }
195 prev = curr;
196 }
197 tiedXPairs += sum(consecutiveXTies - 1);
198 tiedXYPairs += sum(consecutiveXYTies - 1);
199
200 long swaps = 0;
201 DoublePair[] pairsDestination = new DoublePair[n];
202 for (int segmentSize = 1; segmentSize < n; segmentSize <<= 1) {
203 for (int offset = 0; offset < n; offset += 2 * segmentSize) {
204 int i = offset;
205 final int iEnd = FastMath.min(i + segmentSize, n);
206 int j = iEnd;
207 final int jEnd = FastMath.min(j + segmentSize, n);
208
209 int copyLocation = offset;
210 while (i < iEnd || j < jEnd) {
211 if (i < iEnd) {
212 if (j < jEnd) {
213 if (Double.compare(pairs[i].getSecond(), pairs[j].getSecond()) <= 0) {
214 pairsDestination[copyLocation] = pairs[i];
215 i++;
216 } else {
217 pairsDestination[copyLocation] = pairs[j];
218 j++;
219 swaps += iEnd - i;
220 }
221 } else {
222 pairsDestination[copyLocation] = pairs[i];
223 i++;
224 }
225 } else {
226 pairsDestination[copyLocation] = pairs[j];
227 j++;
228 }
229 copyLocation++;
230 }
231 }
232 final DoublePair[] pairsTemp = pairs;
233 pairs = pairsDestination;
234 pairsDestination = pairsTemp;
235 }
236
237 long tiedYPairs = 0;
238 long consecutiveYTies = 1;
239 prev = pairs[0];
240 for (int i = 1; i < n; i++) {
241 final DoublePair curr = pairs[i];
242 if (Double.compare(curr.getSecond(), prev.getSecond()) == 0) {
243 consecutiveYTies++;
244 } else {
245 tiedYPairs += sum(consecutiveYTies - 1);
246 consecutiveYTies = 1;
247 }
248 prev = curr;
249 }
250 tiedYPairs += sum(consecutiveYTies - 1);
251
252 final long concordantMinusDiscordant = numPairs - tiedXPairs - tiedYPairs + tiedXYPairs - 2 * swaps;
253 final double nonTiedPairsMultiplied = (numPairs - tiedXPairs) * (double) (numPairs - tiedYPairs);
254 return concordantMinusDiscordant / FastMath.sqrt(nonTiedPairsMultiplied);
255 }
256
257 /**
258 * Returns the sum of the number from 1 .. n according to Gauss' summation formula:
259 * \[ \sum\limits_{k=1}^n k = \frac{n(n + 1)}{2} \]
260 *
261 * @param n the summation end
262 * @return the sum of the number from 1 to n
263 */
264 private static long sum(long n) {
265 return n * (n + 1) / 2l;
266 }
267
268 /**
269 * Helper data structure holding a (double, double) pair.
270 */
271 private static class DoublePair {
272 /** The first value */
273 private final double first;
274 /** The second value */
275 private final double second;
276
277 /**
278 * @param first first value.
279 * @param second second value.
280 */
281 DoublePair(double first, double second) {
282 this.first = first;
283 this.second = second;
284 }
285
286 /** @return the first value. */
287 public double getFirst() {
288 return first;
289 }
290
291 /** @return the second value. */
292 public double getSecond() {
293 return second;
294 }
295
296 }
297
298 }