1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 package org.hipparchus.random;
23
24 import java.io.Serializable;
25 import java.util.ArrayList;
26 import java.util.Arrays;
27 import java.util.Collection;
28 import java.util.List;
29 import java.util.Map;
30 import java.util.concurrent.ConcurrentHashMap;
31
32 import org.hipparchus.distribution.EnumeratedDistribution;
33 import org.hipparchus.distribution.IntegerDistribution;
34 import org.hipparchus.distribution.RealDistribution;
35 import org.hipparchus.distribution.continuous.BetaDistribution;
36 import org.hipparchus.distribution.continuous.EnumeratedRealDistribution;
37 import org.hipparchus.distribution.continuous.ExponentialDistribution;
38 import org.hipparchus.distribution.continuous.GammaDistribution;
39 import org.hipparchus.distribution.continuous.LogNormalDistribution;
40 import org.hipparchus.distribution.continuous.NormalDistribution;
41 import org.hipparchus.distribution.continuous.UniformRealDistribution;
42 import org.hipparchus.distribution.discrete.EnumeratedIntegerDistribution;
43 import org.hipparchus.distribution.discrete.PoissonDistribution;
44 import org.hipparchus.distribution.discrete.UniformIntegerDistribution;
45 import org.hipparchus.distribution.discrete.ZipfDistribution;
46 import org.hipparchus.exception.LocalizedCoreFormats;
47 import org.hipparchus.exception.MathIllegalArgumentException;
48 import org.hipparchus.util.CombinatoricsUtils;
49 import org.hipparchus.util.FastMath;
50 import org.hipparchus.util.MathArrays;
51 import org.hipparchus.util.MathUtils;
52 import org.hipparchus.util.Pair;
53 import org.hipparchus.util.Precision;
54 import org.hipparchus.util.ResizableDoubleArray;
55
56
57
58
59 public class RandomDataGenerator extends ForwardingRandomGenerator
60 implements RandomGenerator, Serializable {
61
62
/** Serialization version identifier. */
private static final long serialVersionUID = 20160529L;

/**
 * Table of partial sums q_i = sum_{k=1..i} (ln 2)^k / k!.
 * It is filled by the static initializer, which stops adding entries once the
 * running sum is no longer below 1, and it is read by
 * {@link #nextExponential(double)}.
 */
private static final double[] EXPONENTIAL_SA_QI;

/**
 * Specialized samplers for continuous distributions, keyed by the exact
 * distribution class. Populated in the static initializer.
 */
private static final Map<Class<? extends RealDistribution>, RealDistributionSampler> CONTINUOUS_SAMPLERS = new ConcurrentHashMap<>();

/**
 * Specialized samplers for discrete distributions, keyed by the exact
 * distribution class. Populated in the static initializer.
 */
private static final Map<Class<? extends IntegerDistribution>, IntegerDistributionSampler> DISCRETE_SAMPLERS = new ConcurrentHashMap<>();

/**
 * Fallback sampler for continuous distributions without a registered sampler:
 * evaluates the inverse cumulative probability at a uniform random double.
 */
private static final RealDistributionSampler DEFAULT_REAL_SAMPLER =
    (generator, dist) -> dist.inverseCumulativeProbability(generator.nextDouble());

/**
 * Fallback sampler for discrete distributions without a registered sampler:
 * evaluates the inverse cumulative probability at a uniform random double.
 */
private static final IntegerDistributionSampler DEFAULT_INTEGER_SAMPLER =
    (generator, dist) -> dist.inverseCumulativeProbability(generator.nextDouble());

/** Wrapped generator that supplies the underlying random draws. */
private final RandomGenerator randomGenerator;

/**
 * Cached Zipf sampler, replaced by {@link #nextZipf(int, double)} whenever the
 * requested parameters differ from the cached ones. Declared transient.
 */
private transient ZipfRejectionInversionSampler zipfSampler;
98
99
100
101
/**
 * Strategy for drawing one value from a continuous distribution.
 */
@FunctionalInterface
private interface RealDistributionSampler {

    /**
     * Draws the next value from {@code distribution}.
     *
     * @param generator data generator the implementation may use as its source of randomness
     * @param distribution distribution to sample from
     * @return the sampled value
     */
    double nextSample(RandomDataGenerator generator, RealDistribution distribution);
}
113
114
115
116
/**
 * Strategy for drawing one value from a discrete (integer-valued) distribution.
 */
@FunctionalInterface
private interface IntegerDistributionSampler {

    /**
     * Draws the next value from {@code distribution}.
     *
     * @param generator data generator the implementation may use as its source of randomness
     * @param distribution distribution to sample from
     * @return the sampled value
     */
    int nextSample(RandomDataGenerator generator, IntegerDistribution distribution);
}
128
129
130
131
// Class initialization: builds the table used by nextExponential and registers
// the distribution-specific samplers looked up by getSampler.
static {
    // Partial sums q_k = sum_{j=1..k} (ln 2)^j / j!. Entries are appended until
    // the running sum is no longer below 1, so the last entry is >= 1.
    final double log2 = FastMath.log(2);
    final ResizableDoubleArray partialSums = new ResizableDoubleArray(20);
    double runningSum = 0;
    for (int k = 1; runningSum < 1; ++k) {
        runningSum += FastMath.pow(log2, k) / CombinatoricsUtils.factorial(k);
        partialSums.addElement(runningSum);
    }
    EXPONENTIAL_SA_QI = partialSums.getElements();

    // ---- Continuous distributions ----

    CONTINUOUS_SAMPLERS.put(BetaDistribution.class, (rdg, d) -> {
        final BetaDistribution betaDist = (BetaDistribution) d;
        return rdg.nextBeta(betaDist.getAlpha(), betaDist.getBeta());
    });

    CONTINUOUS_SAMPLERS.put(ExponentialDistribution.class, (rdg, d) -> {
        final double mean = d.getNumericalMean();
        return rdg.nextExponential(mean);
    });

    CONTINUOUS_SAMPLERS.put(GammaDistribution.class, (rdg, d) -> {
        final GammaDistribution gammaDist = (GammaDistribution) d;
        return rdg.nextGamma(gammaDist.getShape(), gammaDist.getScale());
    });

    CONTINUOUS_SAMPLERS.put(NormalDistribution.class, (rdg, d) -> {
        final NormalDistribution normalDist = (NormalDistribution) d;
        return rdg.nextNormal(normalDist.getMean(), normalDist.getStandardDeviation());
    });

    CONTINUOUS_SAMPLERS.put(LogNormalDistribution.class, (rdg, d) -> {
        final LogNormalDistribution logNormalDist = (LogNormalDistribution) d;
        return rdg.nextLogNormal(logNormalDist.getShape(), logNormalDist.getLocation());
    });

    CONTINUOUS_SAMPLERS.put(UniformRealDistribution.class, (rdg, d) -> {
        final double lo = d.getSupportLowerBound();
        final double hi = d.getSupportUpperBound();
        return rdg.nextUniform(lo, hi);
    });

    CONTINUOUS_SAMPLERS.put(EnumeratedRealDistribution.class, (rdg, d) -> {
        final EnumeratedRealDistribution enumerated = (EnumeratedRealDistribution) d;
        return rdg.new EnumeratedDistributionSampler<Double>(enumerated.getPmf()).sample();
    });

    // ---- Discrete distributions ----

    DISCRETE_SAMPLERS.put(PoissonDistribution.class,
                          (rdg, d) -> rdg.nextPoisson(d.getNumericalMean()));

    DISCRETE_SAMPLERS.put(UniformIntegerDistribution.class,
                          (rdg, d) -> rdg.nextInt(d.getSupportLowerBound(), d.getSupportUpperBound()));

    DISCRETE_SAMPLERS.put(ZipfDistribution.class, (rdg, d) -> {
        final ZipfDistribution zipf = (ZipfDistribution) d;
        return rdg.nextZipf(zipf.getNumberOfElements(), zipf.getExponent());
    });

    DISCRETE_SAMPLERS.put(EnumeratedIntegerDistribution.class, (rdg, d) -> {
        final EnumeratedIntegerDistribution enumerated = (EnumeratedIntegerDistribution) d;
        return rdg.new EnumeratedDistributionSampler<Integer>(enumerated.getPmf()).sample();
    });
}
225
226
227
228
/**
 * Builds a data generator backed by a {@link Well19937c} generator created
 * with its no-argument constructor.
 */
public RandomDataGenerator() {
    this(new Well19937c());
}
232
233
234
235
236
237
238
/**
 * Builds a data generator backed by a {@link Well19937c} generator created
 * with the given seed.
 *
 * @param seed seed passed to the underlying {@link Well19937c}
 */
public RandomDataGenerator(long seed) {
    this(new Well19937c(seed));
}
242
243
244
245
246
247
248
/**
 * Builds a data generator wrapping the given generator. Private: external
 * callers go through {@link #of(RandomGenerator)}.
 *
 * @param randomGenerator generator to wrap; checked with {@code MathUtils.checkNotNull}
 */
private RandomDataGenerator(RandomGenerator randomGenerator) {
    MathUtils.checkNotNull(randomGenerator);
    this.randomGenerator = randomGenerator;
}
253
254
255
256
257
258
259
260
261
/**
 * Factory method creating a data generator that wraps the given generator.
 *
 * @param randomGenerator generator to wrap (null is rejected by the constructor's null check)
 * @return a new {@code RandomDataGenerator} using {@code randomGenerator} as its source of randomness
 */
public static RandomDataGenerator of(RandomGenerator randomGenerator) {
    return new RandomDataGenerator(randomGenerator);
}
265
266
/**
 * Returns the wrapped generator.
 * NOTE(review): presumably the superclass forwards all {@code RandomGenerator}
 * calls to this delegate; confirm in {@code ForwardingRandomGenerator}.
 *
 * @return the wrapped random generator
 */
@Override
protected RandomGenerator delegate() {
    return randomGenerator;
}
271
272
273
274
275
276
277
278
279
/**
 * Returns the next pseudo-random beta-distributed value, delegating to
 * {@code ChengBetaSampler.sample} with the wrapped generator.
 *
 * @param alpha first shape parameter
 * @param beta second shape parameter
 * @return the sampled value
 */
public double nextBeta(double alpha, double beta) {
    return ChengBetaSampler.sample(randomGenerator, alpha, beta);
}
283
284
285
286
287
288
289
290 public double nextExponential(double mean) {
291 if (mean <= 0) {
292 throw new MathIllegalArgumentException(LocalizedCoreFormats.MEAN, mean);
293 }
294
295 double a = 0;
296 double u = randomGenerator.nextDouble();
297
298
299 while (u < 0.5) {
300 a += EXPONENTIAL_SA_QI[0];
301 u *= 2;
302 }
303
304
305 u += u - 1;
306
307
308 if (u <= EXPONENTIAL_SA_QI[0]) {
309 return mean * (a + u);
310 }
311
312
313 int i = 0;
314 double u2 = randomGenerator.nextDouble();
315 double umin = u2;
316
317
318 do {
319 ++i;
320 u2 = randomGenerator.nextDouble();
321
322 if (u2 < umin) {
323 umin = u2;
324 }
325
326
327 } while (u > EXPONENTIAL_SA_QI[i]);
328
329 return mean * (a + umin * EXPONENTIAL_SA_QI[0]);
330 }
331
332
333
334
335
336
337
338
339 public double nextGamma(double shape, double scale) {
340 if (shape < 1) {
341
342
343 while (true) {
344
345 final double u = randomGenerator.nextDouble();
346 final double bGS = 1 + shape / FastMath.E;
347 final double p = bGS * u;
348
349 if (p <= 1) {
350
351
352 final double x = FastMath.pow(p, 1 / shape);
353 final double u2 = randomGenerator.nextDouble();
354
355 if (u2 > FastMath.exp(-x)) {
356
357 continue;
358 } else {
359 return scale * x;
360 }
361 } else {
362
363
364 final double x = -1 * FastMath.log((bGS - p) / shape);
365 final double u2 = randomGenerator.nextDouble();
366
367 if (u2 > FastMath.pow(x, shape - 1)) {
368
369 continue;
370 } else {
371 return scale * x;
372 }
373 }
374 }
375 }
376
377
378
379 final double d = shape - 0.333333333333333333;
380 final double c = 1 / (3 * FastMath.sqrt(d));
381
382 while (true) {
383 final double x = randomGenerator.nextGaussian();
384 final double v = (1 + c * x) * (1 + c * x) * (1 + c * x);
385
386 if (v <= 0) {
387 continue;
388 }
389
390 final double x2 = x * x;
391 final double u = randomGenerator.nextDouble();
392
393
394 if (u < 1 - 0.0331 * x2 * x2) {
395 return scale * d * v;
396 }
397
398 if (FastMath.log(u) < 0.5 * x2 + d * (1 - v + FastMath.log(v))) {
399 return scale * d * v;
400 }
401 }
402 }
403
404
405
406
407
408
409
410
411 public double nextNormal(double mean, double standardDeviation) {
412 if (standardDeviation <= 0) {
413 throw new MathIllegalArgumentException (LocalizedCoreFormats.NUMBER_TOO_SMALL, standardDeviation, 0);
414 }
415 return standardDeviation * nextGaussian() + mean;
416 }
417
418
419
420
421
422
423
424
425 public double nextLogNormal(double shape, double scale) {
426 if (shape <= 0) {
427 throw new MathIllegalArgumentException (LocalizedCoreFormats.NUMBER_TOO_SMALL, shape, 0);
428 }
429 return FastMath.exp(scale + shape * nextGaussian());
430 }
431
432
433
434
435
436
437
438
/**
 * Returns the next pseudo-random Poisson-distributed value.
 * <p>
 * For means below 40 the method multiplies uniform deviates together and
 * counts how many factors are needed before the product falls below
 * {@code exp(-mean)}. For larger means the mean is split into its integer and
 * fractional parts: the fractional part is sampled by a recursive call and the
 * integer part by an acceptance-rejection loop driven by Gaussian and
 * exponential deviates. The result is capped at {@code Integer.MAX_VALUE}.
 * </p>
 *
 * @param mean mean of the distribution
 * @return the sampled value
 * @throws MathIllegalArgumentException if {@code mean <= 0}
 */
public int nextPoisson(double mean) {
    if (mean <= 0) {
        throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL, mean, 0);
    }
    // Threshold between the multiplication scheme and the rejection scheme.
    final double pivot = 40.0d;
    if (mean < pivot) {
        double p = FastMath.exp(-mean);
        long n = 0;
        double r = 1.0d;
        double rnd;

        // Count uniform factors until the running product r drops below p.
        // The loop is bounded by 1000 * mean iterations.
        while (n < 1000 * mean) {
            rnd = randomGenerator.nextDouble();
            r *= rnd;
            if (r >= p) {
                n++;
            } else {
                return (int) FastMath.min(n, Integer.MAX_VALUE);
            }
        }
        return (int) FastMath.min(n, Integer.MAX_VALUE);
    } else {
        // Integer part (lambda) and fractional part of the mean.
        final double lambda = FastMath.floor(mean);
        final double lambdaFractional = mean - lambda;
        final double logLambda = FastMath.log(lambda);
        final double logLambdaFactorial = CombinatoricsUtils.factorialLog((int) lambda);
        // Contribution of the fractional part, sampled recursively (0 when there is no fraction).
        final long y2 = lambdaFractional < Double.MIN_VALUE ? 0 : nextPoisson(lambdaFractional);
        final double delta = FastMath.sqrt(lambda * FastMath.log(32 * lambda / FastMath.PI + 1));
        final double halfDelta = delta / 2;
        final double twolpd = 2 * lambda + delta;
        // Weights a1, a2 and 1 are normalized into the selection probabilities p1 and p2
        // (the remaining probability selects the direct-acceptance branch below).
        final double a1 = FastMath.sqrt(FastMath.PI * twolpd) * FastMath.exp(1 / (8 * lambda));
        final double a2 = (twolpd / delta) * FastMath.exp(-delta * (1 + delta) / twolpd);
        final double aSum = a1 + a2 + 1;
        final double p1 = a1 / aSum;
        final double p2 = a2 / aSum;
        final double c1 = 1 / (8 * lambda);

        double x;
        double y = 0;
        double v;
        int a;
        double t;
        double qr;
        double qa;
        for (;;) {
            final double u = randomGenerator.nextDouble();
            if (u <= p1) {
                // Gaussian-based proposal; proposals outside [-lambda, delta] are discarded.
                final double n = randomGenerator.nextGaussian();
                x = n * FastMath.sqrt(lambda + halfDelta) - 0.5d;
                if (x > delta || x < -lambda) {
                    continue;
                }
                // Round away from zero.
                y = x < 0 ? FastMath.floor(x) : FastMath.ceil(x);
                final double e = nextExponential(1);
                v = -e - (n * n / 2) + c1;
            } else {
                if (u > p1 + p2) {
                    // Remaining probability mass: accept the offset 0, i.e. y = lambda.
                    y = lambda;
                    break;
                } else {
                    // Exponential-based proposal beyond delta.
                    x = delta + (twolpd / delta) * nextExponential(1);
                    y = FastMath.ceil(x);
                    v = -nextExponential(1) - delta * (x + 1) / twolpd;
                }
            }
            // a flags a negative proposal.
            a = x < 0 ? 1 : 0;
            t = y * (y + 1) / (2 * lambda);
            // Quick acceptance for non-negative proposals.
            if (v < -t && a == 0) {
                y = lambda + y;
                break;
            }
            // Cheap bounds qa <= ... <= qr tried before the exact logarithmic test.
            qr = t * ((2 * y + 1) / (6 * lambda) - 1);
            qa = qr - (t * t) / (3 * (lambda + a * (y + 1)));
            if (v < qa) {
                y = lambda + y;
                break;
            }
            if (v > qr) {
                continue;
            }
            // Exact test using log-factorials.
            if (v < y * logLambda - CombinatoricsUtils.factorialLog((int) (y + lambda)) + logLambdaFactorial) {
                y = lambda + y;
                break;
            }
        }
        // Combine integer-part and fractional-part samples, capped at Integer.MAX_VALUE.
        return (int) FastMath.min(y2 + (long) y, Integer.MAX_VALUE);
    }
}
527
528
529
530
531
532
533
/**
 * Returns one random value drawn from the given continuous distribution,
 * using the sampler registered for its class or the inversion-based default.
 *
 * @param dist distribution to sample from
 * @return the sampled value
 */
public double nextDeviate(RealDistribution dist) {
    return getSampler(dist).nextSample(this, dist);
}
537
538
539
540
541
542
543
544
545
546 public double[] nextDeviates(RealDistribution dist, int size) {
547
548
549 RealDistributionSampler sampler = getSampler(dist);
550 double[] out = new double[size];
551 for (int i = 0; i < size; i++) {
552 out[i] = sampler.nextSample(this, dist);
553 }
554 return out;
555 }
556
557
558
559
560
561
562
/**
 * Returns one random value drawn from the given discrete distribution,
 * using the sampler registered for its class or the inversion-based default.
 *
 * @param dist distribution to sample from
 * @return the sampled value
 */
public int nextDeviate(IntegerDistribution dist) {
    return getSampler(dist).nextSample(this, dist);
}
566
567
568
569
570
571
572
573
574
575 public int[] nextDeviates(IntegerDistribution dist, int size) {
576
577
578 IntegerDistributionSampler sampler = getSampler(dist);
579 int[] out = new int[size];
580 for (int i = 0; i < size; i++) {
581 out[i] = sampler.nextSample(this, dist);
582 }
583 return out;
584 }
585
586
587
588
589
590
591 private RealDistributionSampler getSampler(RealDistribution dist) {
592 RealDistributionSampler sampler = CONTINUOUS_SAMPLERS.get(dist.getClass());
593 if (sampler != null) {
594 return sampler;
595 }
596 return DEFAULT_REAL_SAMPLER;
597 }
598
599
600
601
602
603
604 private IntegerDistributionSampler getSampler(IntegerDistribution dist) {
605 IntegerDistributionSampler sampler = DISCRETE_SAMPLERS.get(dist.getClass());
606 if (sampler != null) {
607 return sampler;
608 }
609 return DEFAULT_INTEGER_SAMPLER;
610 }
611
612
613
614
615
616
617
618
619
620 public int nextInt(int lower, int upper) {
621 if (lower >= upper) {
622 throw new MathIllegalArgumentException(LocalizedCoreFormats.LOWER_BOUND_NOT_BELOW_UPPER_BOUND,
623 lower, upper);
624 }
625 final int max = (upper - lower) + 1;
626 if (max <= 0) {
627
628
629
630 while (true) {
631 final int r = nextInt();
632 if (r >= lower &&
633 r <= upper) {
634 return r;
635 }
636 }
637 } else {
638
639 return lower + nextInt(max);
640 }
641 }
642
643
644
645
646
647
648
649
650
651 public long nextLong(final long lower, final long upper) throws MathIllegalArgumentException {
652 if (lower >= upper) {
653 throw new MathIllegalArgumentException(LocalizedCoreFormats.LOWER_BOUND_NOT_BELOW_UPPER_BOUND,
654 lower, upper);
655 }
656 final long max = (upper - lower) + 1;
657 if (max <= 0) {
658
659
660 while (true) {
661 final long r = randomGenerator.nextLong();
662 if (r >= lower && r <= upper) {
663 return r;
664 }
665 }
666 } else if (max < Integer.MAX_VALUE){
667
668 return lower + randomGenerator.nextInt((int) max);
669 } else {
670
671 return lower + nextLong(max);
672 }
673 }
674
675
676
677
678
679
680
681
682 public double nextUniform(double lower, double upper) {
683 if (upper <= lower) {
684 throw new MathIllegalArgumentException(LocalizedCoreFormats.LOWER_BOUND_NOT_BELOW_UPPER_BOUND, lower, upper);
685 }
686 if (Double.isInfinite(lower) || Double.isInfinite(upper)) {
687 throw new MathIllegalArgumentException(LocalizedCoreFormats.INFINITE_BOUND);
688 }
689 if (Double.isNaN(lower) || Double.isNaN(upper)) {
690 throw new MathIllegalArgumentException(LocalizedCoreFormats.NAN_NOT_ALLOWED);
691 }
692 final double u = randomGenerator.nextDouble();
693 return u * upper + (1 - u) * lower;
694 }
695
696
697
698
699
700
701
702
703 public int nextZipf(int numberOfElements, double exponent) {
704 if (zipfSampler == null || zipfSampler.getExponent() != exponent || zipfSampler.getNumberOfElements() != numberOfElements) {
705 zipfSampler = new ZipfRejectionInversionSampler(numberOfElements, exponent);
706 }
707 return zipfSampler.sample(randomGenerator);
708 }
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728 public String nextHexString(int len) throws MathIllegalArgumentException {
729 if (len <= 0) {
730 throw new MathIllegalArgumentException(LocalizedCoreFormats.LENGTH, len);
731 }
732
733
734 StringBuilder outBuffer = new StringBuilder();
735
736
737 byte[] randomBytes = new byte[(len / 2) + 1];
738 randomGenerator.nextBytes(randomBytes);
739
740
741 for (int i = 0; i < randomBytes.length; i++) {
742 Integer c = Integer.valueOf(randomBytes[i]);
743
744
745
746
747
748
749 String hex = Integer.toHexString(c.intValue() + 128);
750
751
752 if (hex.length() == 1) {
753 outBuffer.append('0');
754 }
755 outBuffer.append(hex);
756 }
757 return outBuffer.toString().substring(0, len);
758 }
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779 public int[] nextPermutation(int n, int k)
780 throws MathIllegalArgumentException {
781 if (k > n) {
782 throw new MathIllegalArgumentException(LocalizedCoreFormats.PERMUTATION_EXCEEDS_N,
783 k, n, true);
784 }
785 if (k <= 0) {
786 throw new MathIllegalArgumentException(LocalizedCoreFormats.PERMUTATION_SIZE,
787 k);
788 }
789
790 final int[] index = MathArrays.natural(n);
791 MathArrays.shuffle(index, randomGenerator);
792
793
794 return Arrays.copyOf(index, k);
795 }
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817 public Object[] nextSample(Collection<?> c, int k) throws MathIllegalArgumentException {
818
819 int len = c.size();
820 if (k > len) {
821 throw new MathIllegalArgumentException(LocalizedCoreFormats.SAMPLE_SIZE_EXCEEDS_COLLECTION_SIZE,
822 k, len, true);
823 }
824 if (k <= 0) {
825 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_OF_SAMPLES, k);
826 }
827
828 Object[] objects = c.toArray();
829 int[] index = nextPermutation(len, k);
830 Object[] result = new Object[k];
831 for (int i = 0; i < k; i++) {
832 result[i] = objects[index[i]];
833 }
834 return result;
835 }
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854 public double[] nextSample(double[] a, int k) throws MathIllegalArgumentException {
855 int len = a.length;
856 if (k > len) {
857 throw new MathIllegalArgumentException(LocalizedCoreFormats.SAMPLE_SIZE_EXCEEDS_COLLECTION_SIZE,
858 k, len, true);
859 }
860 if (k <= 0) {
861 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_OF_SAMPLES, k);
862 }
863 int[] index = nextPermutation(len, k);
864 double[] result = new double[k];
865 for (int i = 0; i < k; i++) {
866 result[i] = a[index[i]];
867 }
868 return result;
869 }
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887 public int[] nextSampleWithReplacement(int sampleSize, double[] weights) {
888
889
890 if (sampleSize < 0) {
891 throw new MathIllegalArgumentException(LocalizedCoreFormats.NOT_POSITIVE_NUMBER_OF_SAMPLES);
892 }
893
894
895 double[] normWt = EnumeratedDistribution.checkAndNormalize(weights);
896
897
898 final int[] out = new int[sampleSize];
899 final int len = normWt.length;
900 for (int i = 0; i < sampleSize; i++) {
901 final double u = randomGenerator.nextDouble();
902 double cum = normWt[0];
903 int j = 1;
904 while (cum < u && j < len) {
905 cum += normWt[j++];
906 }
907 out[i] = --j;
908 }
909 return out;
910 }
911
912
913
914
915
916
917
918
919
920
921
922
/**
 * Static helper that samples beta-distributed values by acceptance-rejection.
 * NOTE(review): judging by the identifiers, the two private routines follow
 * Cheng's algorithms "BB" and "BC"; confirm against the literature reference.
 */
private static class ChengBetaSampler {

    /** Not meant to be instantiated: the class only offers static methods. */
    private ChengBetaSampler() {

    }

    /**
     * Returns one beta-distributed sample.
     * <p>
     * Dispatches on the smaller shape parameter: when {@code min(alpha, beta) > 1}
     * {@code algorithmBB} is called with (a = min, b = max), otherwise
     * {@code algorithmBC} is called with (a = max, b = min).
     * </p>
     *
     * @param generator source of uniform deviates
     * @param alpha first shape parameter
     * @param beta second shape parameter
     * @return the sampled value
     */
    public static double sample(RandomGenerator generator,
                                double alpha,
                                double beta) {
        // NOTE(review): the shape parameters are not validated here.
        final double a = FastMath.min(alpha, beta);
        final double b = FastMath.max(alpha, beta);

        if (a > 1) {
            return algorithmBB(generator, alpha, a, b);
        } else {
            return algorithmBC(generator, alpha, b, a);
        }
    }

    /**
     * Rejection sampler used when both shape parameters exceed 1.
     *
     * @param generator source of uniform deviates
     * @param a0 first shape parameter as supplied by the caller; compared with
     * {@code a} at the end to decide which ratio is returned
     * @param a smaller shape parameter
     * @param b larger shape parameter
     * @return the sampled value
     */
    private static double algorithmBB(final RandomGenerator generator,
                                      final double a0,
                                      final double a,
                                      final double b) {
        final double alpha = a + b;
        final double beta = FastMath.sqrt((alpha - 2.) / (2. * a * b - alpha));
        final double gamma = a + 1. / beta;

        double r;
        double w;
        double t;
        do {
            final double u1 = generator.nextDouble();
            final double u2 = generator.nextDouble();
            final double v = beta * (FastMath.log(u1) - FastMath.log1p(-u1));
            w = a * FastMath.exp(v);
            final double z = u1 * u1 * u2;
            // 1.3862944 is approximately ln 4.
            r = gamma * v - 1.3862944;
            final double s = a + r - w;
            // First quick-acceptance test (2.609438 is approximately 1 + ln 5).
            if (s + 2.609438 >= 5 * z) {
                break;
            }

            // Second quick-acceptance test.
            t = FastMath.log(z);
            if (s >= t) {
                break;
            }
            // Full test: loop again while the candidate is rejected.
        } while (r + alpha * (FastMath.log(alpha) - FastMath.log(b + w)) < t);

        // Guard against an infinite w before forming the ratio.
        w = FastMath.min(w, Double.MAX_VALUE);
        // Return the ratio matching the caller's parameter order.
        return Precision.equals(a, a0) ? w / (b + w) : b / (b + w);
    }

    /**
     * Rejection sampler used when at least one shape parameter is at most 1.
     *
     * @param generator source of uniform deviates
     * @param a0 first shape parameter as supplied by the caller; compared with
     * {@code a} at the end to decide which ratio is returned
     * @param a larger shape parameter
     * @param b smaller shape parameter
     * @return the sampled value
     */
    private static double algorithmBC(final RandomGenerator generator,
                                      final double a0,
                                      final double a,
                                      final double b) {
        final double alpha = a + b;
        final double beta = 1. / b;
        final double delta = 1. + a - b;
        final double k1 = delta * (0.0138889 + 0.0416667 * b) / (a * beta - 0.777778);
        final double k2 = 0.25 + (0.5 + 0.25 / delta) * b;

        double w;
        for (;;) {
            final double u1 = generator.nextDouble();
            final double u2 = generator.nextDouble();
            final double y = u1 * u2;
            final double z = u1 * y;
            if (u1 < 0.5) {
                // Quick rejection for the lower half of u1.
                if (0.25 * u2 + z - y >= k1) {
                    continue;
                }
            } else {
                // Quick acceptance for the upper half of u1.
                if (z <= 0.25) {
                    final double v = beta * (FastMath.log(u1) - FastMath.log1p(-u1));
                    w = a * FastMath.exp(v);
                    break;
                }

                // Quick rejection for the upper half of u1.
                if (z >= k2) {
                    continue;
                }
            }

            // Full acceptance test (1.3862944 is approximately ln 4).
            final double v = beta * (FastMath.log(u1) - FastMath.log1p(-u1));
            w = a * FastMath.exp(v);
            if (alpha * (FastMath.log(alpha) - FastMath.log(b + w) + v) - 1.3862944 >= FastMath.log(z)) {
                break;
            }
        }

        // Guard against an infinite w before forming the ratio.
        w = FastMath.min(w, Double.MAX_VALUE);
        // Return the ratio matching the caller's parameter order.
        return Precision.equals(a, a0) ? w / (b + w) : b / (b + w);
    }
}
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
/**
 * Sampler for Zipf-distributed integers in {@code [1, numberOfElements]}.
 * <p>
 * It works with the function {@code h(x) = x^(-exponent)}, its integral
 * {@code hIntegral}, and the inverse of that integral. A uniform deviate is
 * mapped through the inverse integral to a candidate, which is then accepted
 * or rejected. The class name indicates a rejection-inversion method.
 * </p>
 */
static final class ZipfRejectionInversionSampler {

    /** Exponent of the distribution. */
    private final double exponent;
    /** Number of elements of the distribution. */
    private final int numberOfElements;
    /** Constant {@code hIntegral(1.5) - 1}. */
    private final double hIntegralX1;
    /** Constant {@code hIntegral(numberOfElements + 0.5)}. */
    private final double hIntegralNumberOfElements;
    /** Constant {@code 2 - hIntegralInverse(hIntegral(2.5) - h(2))}, used for quick acceptance. */
    private final double s;

    /**
     * Builds a sampler and precomputes its constants.
     * NOTE(review): the arguments are not validated here.
     *
     * @param numberOfElements number of elements of the distribution
     * @param exponent exponent of the distribution
     */
    ZipfRejectionInversionSampler(final int numberOfElements, final double exponent) {
        this.exponent = exponent;
        this.numberOfElements = numberOfElements;
        this.hIntegralX1 = hIntegral(1.5) - 1d;
        this.hIntegralNumberOfElements = hIntegral(numberOfElements + 0.5);
        this.s = 2d - hIntegralInverse(hIntegral(2.5) - h(2));
    }

    /**
     * Draws one sample.
     *
     * @param random source of uniform deviates
     * @return a value in {@code [1, numberOfElements]}
     */
    int sample(final RandomGenerator random) {
        while(true) {

            // u is spread uniformly between hIntegralNumberOfElements and hIntegralX1.
            final double u = hIntegralNumberOfElements + random.nextDouble() * (hIntegralX1 - hIntegralNumberOfElements);

            // Map u back to the x axis and round to the nearest integer candidate.
            double x = hIntegralInverse(u);

            int k = (int)(x + 0.5);

            // Clamp the candidate into [1, numberOfElements]; rounding may push
            // it just outside that range.
            if (k < 1) {
                k = 1;
            }
            else if (k > numberOfElements) {
                k = numberOfElements;
            }

            // Accept when x is close enough to k (cheap test against the
            // precomputed s), or when u passes the exact test that needs
            // hIntegral and h at k. Otherwise draw again.
            if (k - x <= s || u >= hIntegral(k + 0.5) - h(k)) {
                return k;
            }
        }
    }

    /**
     * Computes {@code helper2((1 - exponent) * log(x)) * log(x)}. Since
     * {@code helper2(t) = (exp(t) - 1) / t}, this equals
     * {@code (x^(1 - exponent) - 1) / (1 - exponent)}, an antiderivative of
     * {@code h}. The helper keeps the expression usable when the exponent is
     * close to 1.
     *
     * @param x argument, expected to be positive (its logarithm is taken)
     * @return value of the integral function at {@code x}
     */
    private double hIntegral(final double x) {
        final double logX = FastMath.log(x);
        return helper2((1d-exponent)*logX)*logX;
    }

    /**
     * Computes {@code h(x) = x^(-exponent)}, evaluated as
     * {@code exp(-exponent * log(x))}.
     *
     * @param x argument
     * @return {@code x^(-exponent)}
     */
    private double h(final double x) {
        return FastMath.exp(-exponent * FastMath.log(x));
    }

    /**
     * Inverse of {@link #hIntegral(double)}, evaluated as
     * {@code exp(helper1(t) * x)} with {@code t = x * (1 - exponent)}.
     *
     * @param x value of the integral function
     * @return the argument whose integral function value is {@code x}
     */
    private double hIntegralInverse(final double x) {
        double t = x*(1d-exponent);
        if (t < -1d) {
            // Keep t within the domain of log1p (t >= -1).
            // NOTE(review): presumably only reachable through rounding errors; confirm.
            t = -1;
        }
        return FastMath.exp(helper1(t)*x);
    }

    /**
     * Returns the exponent of the distribution.
     *
     * @return the exponent
     */
    public double getExponent() {
        return exponent;
    }

    /**
     * Returns the number of elements of the distribution.
     *
     * @return the number of elements
     */
    public int getNumberOfElements() {
        return numberOfElements;
    }

    /**
     * Computes {@code log1p(x) / x}. For {@code |x| <= 1e-8} a truncated
     * series {@code 1 - x/2 + x^2/3 - x^3/4} is used instead of the quotient.
     *
     * @param x argument
     * @return {@code log(1 + x) / x}
     */
    static double helper1(final double x) {
        if (FastMath.abs(x)>1e-8) {
            return FastMath.log1p(x)/x;
        }
        else {
            return 1.-x*((1./2.)-x*((1./3.)-x*(1./4.)));
        }
    }

    /**
     * Computes {@code expm1(x) / x}. For {@code |x| <= 1e-8} a truncated
     * series {@code 1 + x/2 + x^2/6 + x^3/24} is used instead of the quotient.
     *
     * @param x argument
     * @return {@code (exp(x) - 1) / x}
     */
    static double helper2(final double x) {
        if (FastMath.abs(x)>1e-8) {
            return FastMath.expm1(x)/x;
        }
        else {
            return 1.+x*(1./2.)*(1.+x*(1./3.)*(1.+x*(1./4.)));
        }
    }
}
1262
1263
1264
1265
1266
1267
1268 private final class EnumeratedDistributionSampler<T> {
1269
1270 private final double[] weights;
1271
1272 private final List<T> values;
1273
1274
1275
1276
1277
1278 EnumeratedDistributionSampler(List<Pair<T, Double>> pmf) {
1279 final int numMasses = pmf.size();
1280 weights = new double[numMasses];
1281 values = new ArrayList<>();
1282 for (int i = 0; i < numMasses; i++) {
1283 weights[i] = pmf.get(i).getSecond();
1284 values.add(pmf.get(i).getFirst());
1285 }
1286 }
1287
1288
1289
1290 public T sample() {
1291 int[] chosen = nextSampleWithReplacement(1, weights);
1292 return values.get(chosen[0]);
1293 }
1294 }
1295 }