1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 package org.hipparchus.stat.correlation;
23
24 import java.util.Arrays;
25
26 import org.hipparchus.exception.MathIllegalArgumentException;
27 import org.hipparchus.linear.BlockRealMatrix;
28 import org.hipparchus.linear.MatrixUtils;
29 import org.hipparchus.linear.RealMatrix;
30 import org.hipparchus.util.FastMath;
31 import org.hipparchus.util.MathArrays;
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72 public class KendallsCorrelation {
73
74
75 private final RealMatrix correlationMatrix;
76
77
78
79
80 public KendallsCorrelation() {
81 correlationMatrix = null;
82 }
83
84
85
86
87
88
89
90
91
92 public KendallsCorrelation(double[][] data) {
93 this(MatrixUtils.createRealMatrix(data));
94 }
95
96
97
98
99
100
101
102 public KendallsCorrelation(RealMatrix matrix) {
103 correlationMatrix = computeCorrelationMatrix(matrix);
104 }
105
106
107
108
109
110
111 public RealMatrix getCorrelationMatrix() {
112 return correlationMatrix;
113 }
114
115
116
117
118
119
120
121
122 public RealMatrix computeCorrelationMatrix(final RealMatrix matrix) {
123 int nVars = matrix.getColumnDimension();
124 RealMatrix outMatrix = new BlockRealMatrix(nVars, nVars);
125 for (int i = 0; i < nVars; i++) {
126 for (int j = 0; j < i; j++) {
127 double corr = correlation(matrix.getColumn(i), matrix.getColumn(j));
128 outMatrix.setEntry(i, j, corr);
129 outMatrix.setEntry(j, i, corr);
130 }
131 outMatrix.setEntry(i, i, 1d);
132 }
133 return outMatrix;
134 }
135
136
137
138
139
140
141
142
143
144 public RealMatrix computeCorrelationMatrix(final double[][] matrix) {
145 return computeCorrelationMatrix(new BlockRealMatrix(matrix));
146 }
147
148
149
150
151
152
153
154
155
156 public double correlation(final double[] xArray, final double[] yArray)
157 throws MathIllegalArgumentException {
158
159 MathArrays.checkEqualLength(xArray, yArray);
160
161 final int n = xArray.length;
162 final long numPairs = sum(n - 1);
163
164 DoublePair[] pairs = new DoublePair[n];
165 for (int i = 0; i < n; i++) {
166 pairs[i] = new DoublePair(xArray[i], yArray[i]);
167 }
168
169 Arrays.sort(pairs, (p1, p2) -> {
170 int compareKey = Double.compare(p1.getFirst(), p2.getFirst());
171 return compareKey != 0 ? compareKey : Double.compare(p1.getSecond(), p2.getSecond());
172 });
173
174 long tiedXPairs = 0;
175 long tiedXYPairs = 0;
176 long consecutiveXTies = 1;
177 long consecutiveXYTies = 1;
178 DoublePair prev = pairs[0];
179 for (int i = 1; i < n; i++) {
180 final DoublePair curr = pairs[i];
181 if (Double.compare(curr.getFirst(), prev.getFirst()) == 0) {
182 consecutiveXTies++;
183 if (Double.compare(curr.getSecond(), prev.getSecond()) == 0) {
184 consecutiveXYTies++;
185 } else {
186 tiedXYPairs += sum(consecutiveXYTies - 1);
187 consecutiveXYTies = 1;
188 }
189 } else {
190 tiedXPairs += sum(consecutiveXTies - 1);
191 consecutiveXTies = 1;
192 tiedXYPairs += sum(consecutiveXYTies - 1);
193 consecutiveXYTies = 1;
194 }
195 prev = curr;
196 }
197 tiedXPairs += sum(consecutiveXTies - 1);
198 tiedXYPairs += sum(consecutiveXYTies - 1);
199
200 long swaps = 0;
201 DoublePair[] pairsDestination = new DoublePair[n];
202 for (int segmentSize = 1; segmentSize < n; segmentSize <<= 1) {
203 for (int offset = 0; offset < n; offset += 2 * segmentSize) {
204 int i = offset;
205 final int iEnd = FastMath.min(i + segmentSize, n);
206 int j = iEnd;
207 final int jEnd = FastMath.min(j + segmentSize, n);
208
209 int copyLocation = offset;
210 while (i < iEnd || j < jEnd) {
211 if (i < iEnd) {
212 if (j < jEnd) {
213 if (Double.compare(pairs[i].getSecond(), pairs[j].getSecond()) <= 0) {
214 pairsDestination[copyLocation] = pairs[i];
215 i++;
216 } else {
217 pairsDestination[copyLocation] = pairs[j];
218 j++;
219 swaps += iEnd - i;
220 }
221 } else {
222 pairsDestination[copyLocation] = pairs[i];
223 i++;
224 }
225 } else {
226 pairsDestination[copyLocation] = pairs[j];
227 j++;
228 }
229 copyLocation++;
230 }
231 }
232 final DoublePair[] pairsTemp = pairs;
233 pairs = pairsDestination;
234 pairsDestination = pairsTemp;
235 }
236
237 long tiedYPairs = 0;
238 long consecutiveYTies = 1;
239 prev = pairs[0];
240 for (int i = 1; i < n; i++) {
241 final DoublePair curr = pairs[i];
242 if (Double.compare(curr.getSecond(), prev.getSecond()) == 0) {
243 consecutiveYTies++;
244 } else {
245 tiedYPairs += sum(consecutiveYTies - 1);
246 consecutiveYTies = 1;
247 }
248 prev = curr;
249 }
250 tiedYPairs += sum(consecutiveYTies - 1);
251
252 final long concordantMinusDiscordant = numPairs - tiedXPairs - tiedYPairs + tiedXYPairs - 2 * swaps;
253 final double nonTiedPairsMultiplied = (numPairs - tiedXPairs) * (double) (numPairs - tiedYPairs);
254 return concordantMinusDiscordant / FastMath.sqrt(nonTiedPairsMultiplied);
255 }
256
257
258
259
260
261
262
263
264 private static long sum(long n) {
265 return n * (n + 1) / 2l;
266 }
267
268
269
270
271 private static class DoublePair {
272
273 private final double first;
274
275 private final double second;
276
277
278
279
280
281 DoublePair(double first, double second) {
282 this.first = first;
283 this.second = second;
284 }
285
286
287 public double getFirst() {
288 return first;
289 }
290
291
292 public double getSecond() {
293 return second;
294 }
295
296 }
297
298 }