1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 package org.hipparchus.distribution.continuous;
23
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import org.hipparchus.distribution.EnumeratedDistribution;
import org.hipparchus.exception.MathIllegalArgumentException;
import org.hipparchus.util.MathArrays;
import org.hipparchus.util.MathUtils;
import org.hipparchus.util.Pair;
35
36
37
38
39
40
41
42
43
44
45 public class EnumeratedRealDistribution extends AbstractRealDistribution {
46
47
48 private static final long serialVersionUID = 20130308L;
49
50
51
52
53
54 private final EnumeratedDistribution<Double> innerDistribution;
55
56
57
58
59
60
61
62
63 public EnumeratedRealDistribution(final double[] data) {
64 super();
65 final Map<Double, Integer> dataMap = new HashMap<>();
66 for (double value : data) {
67 Integer count = dataMap.get(value);
68 if (count == null) {
69 count = 0;
70 }
71 dataMap.put(value, ++count);
72 }
73 final int massPoints = dataMap.size();
74 final double denom = data.length;
75 final double[] values = new double[massPoints];
76 final double[] probabilities = new double[massPoints];
77 int index = 0;
78 for (Entry<Double, Integer> entry : dataMap.entrySet()) {
79 values[index] = entry.getKey();
80 probabilities[index] = entry.getValue().intValue() / denom;
81 index++;
82 }
83 innerDistribution =
84 new EnumeratedDistribution<>(createDistribution(values, probabilities));
85 }
86
87
88
89
90
91
92
93
94
95
96
97
98
99 public EnumeratedRealDistribution(final double[] singletons, final double[] probabilities)
100 throws MathIllegalArgumentException {
101 super();
102 innerDistribution =
103 new EnumeratedDistribution<>(createDistribution(singletons, probabilities));
104 }
105
106
107
108
109
110
111
112
113
114
115 private static List<Pair<Double, Double>> createDistribution(double[] singletons,
116 double[] probabilities) {
117 MathArrays.checkEqualLength(singletons, probabilities);
118 final List<Pair<Double, Double>> samples = new ArrayList<>(singletons.length);
119
120 final double[] normalizedProbabilities = EnumeratedDistribution.checkAndNormalize(probabilities);
121 for (int i = 0; i < singletons.length; i++) {
122 samples.add(new Pair<>(singletons[i], normalizedProbabilities[i]));
123 }
124 return samples;
125 }
126
127
128
129
130
131
132
133
134
135
136
137
138
139 public double probability(final double x) {
140 return innerDistribution.probability(x);
141 }
142
143
144
145
146
147
148
149
150
151
152 @Override
153 public double density(final double x) {
154 return probability(x);
155 }
156
157
158
159
160 @Override
161 public double cumulativeProbability(final double x) {
162 double probability = 0;
163
164 for (final Pair<Double, Double> sample : innerDistribution.getPmf()) {
165 if (sample.getKey() <= x) {
166 probability += sample.getValue();
167 }
168 }
169
170 return probability;
171 }
172
173
174
175
176 @Override
177 public double inverseCumulativeProbability(final double p) throws MathIllegalArgumentException {
178 MathUtils.checkRangeInclusive(p, 0, 1);
179
180 double probability = 0;
181 double x = getSupportLowerBound();
182 for (final Pair<Double, Double> sample : innerDistribution.getPmf()) {
183 if (sample.getValue() == 0.0) {
184 continue;
185 }
186
187 probability += sample.getValue();
188 x = sample.getKey();
189
190 if (probability >= p) {
191 break;
192 }
193 }
194
195 return x;
196 }
197
198
199
200
201
202
203 @Override
204 public double getNumericalMean() {
205 double mean = 0;
206
207 for (final Pair<Double, Double> sample : innerDistribution.getPmf()) {
208 mean += sample.getValue() * sample.getKey();
209 }
210
211 return mean;
212 }
213
214
215
216
217
218
219 @Override
220 public double getNumericalVariance() {
221 double mean = 0;
222 double meanOfSquares = 0;
223
224 for (final Pair<Double, Double> sample : innerDistribution.getPmf()) {
225 mean += sample.getValue() * sample.getKey();
226 meanOfSquares += sample.getValue() * sample.getKey() * sample.getKey();
227 }
228
229 return meanOfSquares - mean * mean;
230 }
231
232
233
234
235
236
237
238
239 @Override
240 public double getSupportLowerBound() {
241 double min = Double.POSITIVE_INFINITY;
242 for (final Pair<Double, Double> sample : innerDistribution.getPmf()) {
243 if (sample.getKey() < min && sample.getValue() > 0) {
244 min = sample.getKey();
245 }
246 }
247
248 return min;
249 }
250
251
252
253
254
255
256
257
258 @Override
259 public double getSupportUpperBound() {
260 double max = Double.NEGATIVE_INFINITY;
261 for (final Pair<Double, Double> sample : innerDistribution.getPmf()) {
262 if (sample.getKey() > max && sample.getValue() > 0) {
263 max = sample.getKey();
264 }
265 }
266
267 return max;
268 }
269
270
271
272
273
274
275
276
277 @Override
278 public boolean isSupportConnected() {
279 return true;
280 }
281
282
283
284
285
286
287 public List<Pair<Double, Double>> getPmf() {
288 return innerDistribution.getPmf();
289 }
290 }