1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23 package org.hipparchus.stat.ranking;
24
25 import java.util.ArrayList;
26 import java.util.Arrays;
27 import java.util.Iterator;
28 import java.util.List;
29
30 import org.hipparchus.exception.LocalizedCoreFormats;
31 import org.hipparchus.exception.MathIllegalArgumentException;
32 import org.hipparchus.exception.MathRuntimeException;
33 import org.hipparchus.random.RandomDataGenerator;
34 import org.hipparchus.random.RandomGenerator;
35 import org.hipparchus.util.FastMath;
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75 public class NaturalRanking implements RankingAlgorithm {
76
77
78 public static final NaNStrategy DEFAULT_NAN_STRATEGY = NaNStrategy.FAILED;
79
80
81 public static final TiesStrategy DEFAULT_TIES_STRATEGY = TiesStrategy.AVERAGE;
82
83
84 private final NaNStrategy nanStrategy;
85
86
87 private final TiesStrategy tiesStrategy;
88
89
90 private final RandomDataGenerator randomData;
91
92
93
94
95 public NaturalRanking() {
96 super();
97 tiesStrategy = DEFAULT_TIES_STRATEGY;
98 nanStrategy = DEFAULT_NAN_STRATEGY;
99 randomData = null;
100 }
101
102
103
104
105
106
107 public NaturalRanking(TiesStrategy tiesStrategy) {
108 super();
109 this.tiesStrategy = tiesStrategy;
110 nanStrategy = DEFAULT_NAN_STRATEGY;
111 randomData = new RandomDataGenerator();
112 }
113
114
115
116
117
118
119 public NaturalRanking(NaNStrategy nanStrategy) {
120 super();
121 this.nanStrategy = nanStrategy;
122 tiesStrategy = DEFAULT_TIES_STRATEGY;
123 randomData = null;
124 }
125
126
127
128
129
130
131
132 public NaturalRanking(NaNStrategy nanStrategy, TiesStrategy tiesStrategy) {
133 super();
134 this.nanStrategy = nanStrategy;
135 this.tiesStrategy = tiesStrategy;
136 randomData = new RandomDataGenerator();
137 }
138
139
140
141
142
143
144
145 public NaturalRanking(RandomGenerator randomGenerator) {
146 super();
147 this.tiesStrategy = TiesStrategy.RANDOM;
148 nanStrategy = DEFAULT_NAN_STRATEGY;
149 randomData = RandomDataGenerator.of(randomGenerator);
150 }
151
152
153
154
155
156
157
158
159
160 public NaturalRanking(NaNStrategy nanStrategy,
161 RandomGenerator randomGenerator) {
162 super();
163 this.nanStrategy = nanStrategy;
164 this.tiesStrategy = TiesStrategy.RANDOM;
165 randomData = RandomDataGenerator.of(randomGenerator);
166 }
167
168
169
170
171
172
173 public NaNStrategy getNanStrategy() {
174 return nanStrategy;
175 }
176
177
178
179
180
181
182 public TiesStrategy getTiesStrategy() {
183 return tiesStrategy;
184 }
185
186
187
188
189
190
191
192
193
194
195
196 @Override
197 public double[] rank(double[] data) {
198
199
200 IntDoublePair[] ranks = new IntDoublePair[data.length];
201 for (int i = 0; i < data.length; i++) {
202 ranks[i] = new IntDoublePair(data[i], i);
203 }
204
205
206 List<Integer> nanPositions = null;
207 switch (nanStrategy) {
208 case MAXIMAL:
209 recodeNaNs(ranks, Double.POSITIVE_INFINITY);
210 break;
211 case MINIMAL:
212 recodeNaNs(ranks, Double.NEGATIVE_INFINITY);
213 break;
214 case REMOVED:
215 ranks = removeNaNs(ranks);
216 break;
217 case FIXED:
218 nanPositions = getNanPositions(ranks);
219 break;
220 case FAILED:
221 nanPositions = getNanPositions(ranks);
222 if (!nanPositions.isEmpty()) {
223 throw new MathIllegalArgumentException(LocalizedCoreFormats.NAN_NOT_ALLOWED);
224 }
225 break;
226 default:
227 throw MathRuntimeException.createInternalError();
228 }
229
230
231 Arrays.sort(ranks, (p1, p2) -> Double.compare(p1.value, p2.value));
232
233
234
235 double[] out = new double[ranks.length];
236 int pos = 1;
237 out[ranks[0].getPosition()] = pos;
238 List<Integer> tiesTrace = new ArrayList<>();
239 tiesTrace.add(ranks[0].getPosition());
240 for (int i = 1; i < ranks.length; i++) {
241 if (Double.compare(ranks[i].getValue(), ranks[i - 1].getValue()) > 0) {
242
243 pos = i + 1;
244 if (tiesTrace.size() > 1) {
245 resolveTie(out, tiesTrace);
246 }
247 tiesTrace = new ArrayList<>();
248 tiesTrace.add(ranks[i].getPosition());
249 } else {
250
251 tiesTrace.add(ranks[i].getPosition());
252 }
253 out[ranks[i].getPosition()] = pos;
254 }
255 if (tiesTrace.size() > 1) {
256 resolveTie(out, tiesTrace);
257 }
258 if (nanStrategy == NaNStrategy.FIXED) {
259 restoreNaNs(out, nanPositions);
260 }
261 return out;
262 }
263
264
265
266
267
268
269
270
271 private IntDoublePair[] removeNaNs(IntDoublePair[] ranks) {
272 if (!containsNaNs(ranks)) {
273 return ranks;
274 }
275 IntDoublePair[] outRanks = new IntDoublePair[ranks.length];
276 int j = 0;
277 for (int i = 0; i < ranks.length; i++) {
278 if (Double.isNaN(ranks[i].getValue())) {
279
280 for (int k = i + 1; k < ranks.length; k++) {
281 ranks[k] = new IntDoublePair(
282 ranks[k].getValue(), ranks[k].getPosition() - 1);
283 }
284 } else {
285 outRanks[j] = new IntDoublePair(
286 ranks[i].getValue(), ranks[i].getPosition());
287 j++;
288 }
289 }
290 IntDoublePair[] returnRanks = new IntDoublePair[j];
291 System.arraycopy(outRanks, 0, returnRanks, 0, j);
292 return returnRanks;
293 }
294
295
296
297
298
299
300
301 private void recodeNaNs(IntDoublePair[] ranks, double value) {
302 for (int i = 0; i < ranks.length; i++) {
303 if (Double.isNaN(ranks[i].getValue())) {
304 ranks[i] = new IntDoublePair(
305 value, ranks[i].getPosition());
306 }
307 }
308 }
309
310
311
312
313
314
315
316 private boolean containsNaNs(IntDoublePair[] ranks) {
317 for (int i = 0; i < ranks.length; i++) {
318 if (Double.isNaN(ranks[i].getValue())) {
319 return true;
320 }
321 }
322 return false;
323 }
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339 private void resolveTie(double[] ranks, List<Integer> tiesTrace) {
340
341
342 final double c = ranks[tiesTrace.get(0)];
343
344
345 final int length = tiesTrace.size();
346
347 switch (tiesStrategy) {
348 case AVERAGE:
349 fill(ranks, tiesTrace, (2 * c + length - 1) / 2d);
350 break;
351 case MAXIMUM:
352 fill(ranks, tiesTrace, c + length - 1);
353 break;
354 case MINIMUM:
355 fill(ranks, tiesTrace, c);
356 break;
357 case RANDOM:
358 Iterator<Integer> iterator = tiesTrace.iterator();
359 long f = FastMath.round(c);
360 while (iterator.hasNext()) {
361
362 ranks[iterator.next()] =
363 randomData.nextLong(f, f + length - 1);
364 }
365 break;
366 case SEQUENTIAL:
367
368 iterator = tiesTrace.iterator();
369 f = FastMath.round(c);
370 int i = 0;
371 while (iterator.hasNext()) {
372 ranks[iterator.next()] = f + i++;
373 }
374 break;
375 default:
376 throw MathRuntimeException.createInternalError();
377 }
378 }
379
380
381
382
383
384
385
386
387 private void fill(double[] data, List<Integer> tiesTrace, double value) {
388 Iterator<Integer> iterator = tiesTrace.iterator();
389 while (iterator.hasNext()) {
390 data[iterator.next()] = value;
391 }
392 }
393
394
395
396
397
398
399
400 private void restoreNaNs(double[] ranks, List<Integer> nanPositions) {
401 if (nanPositions.isEmpty()) {
402 return;
403 }
404 Iterator<Integer> iterator = nanPositions.iterator();
405 while (iterator.hasNext()) {
406 ranks[iterator.next().intValue()] = Double.NaN;
407 }
408
409 }
410
411
412
413
414
415
416
417 private List<Integer> getNanPositions(IntDoublePair[] ranks) {
418 ArrayList<Integer> out = new ArrayList<>();
419 for (int i = 0; i < ranks.length; i++) {
420 if (Double.isNaN(ranks[i].getValue())) {
421 out.add(Integer.valueOf(i));
422 }
423 }
424 return out;
425 }
426
427
428
429
430 private static class IntDoublePair {
431
432
433 private final double value;
434
435
436 private final int position;
437
438
439
440
441
442
443 IntDoublePair(double value, int position) {
444 this.value = value;
445 this.position = position;
446 }
447
448
449
450
451
452 public double getValue() {
453 return value;
454 }
455
456
457
458
459
460 public int getPosition() {
461 return position;
462 }
463 }
464 }