1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * https://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 /*
19 * This is not the original file distributed by the Apache Software Foundation
20 * It has been modified by the Hipparchus project
21 */
22 package org.hipparchus.distribution.discrete;
23
24 import java.util.ArrayList;
25 import java.util.HashMap;
26 import java.util.List;
27 import java.util.Map;
28 import java.util.Map.Entry;
29
30 import org.hipparchus.distribution.EnumeratedDistribution;
31 import org.hipparchus.exception.MathIllegalArgumentException;
32 import org.hipparchus.util.MathUtils;
33 import org.hipparchus.util.Pair;
34
35 /**
36 * Implementation of an integer-valued {@link EnumeratedDistribution}.
37 * <p>
38 * Values with zero-probability are allowed but they do not extend the
39 * support.
40 * <p>
41 * Duplicate values are allowed. Probabilities of duplicate values are combined
42 * when computing cumulative probabilities and statistics.
43 */
44 public class EnumeratedIntegerDistribution extends AbstractIntegerDistribution {
45
46 /** Serializable UID. */
47 private static final long serialVersionUID = 20130308L;
48
49 /**
50 * {@link EnumeratedDistribution} instance (using the {@link Integer} wrapper)
51 * used to generate the pmf.
52 */
53 private final EnumeratedDistribution<Integer> innerDistribution;
54
55 /**
56 * Create a discrete distribution using the given probability mass function
57 * definition.
58 *
59 * @param singletons array of random variable values.
60 * @param probabilities array of probabilities.
61 * @throws MathIllegalArgumentException if
62 * {@code singletons.length != probabilities.length}
63 * @throws MathIllegalArgumentException if probabilities contains negative, infinite or NaN values or only 0's
64 */
65 public EnumeratedIntegerDistribution(final int[] singletons, final double[] probabilities)
66 throws MathIllegalArgumentException {
67 innerDistribution =
68 new EnumeratedDistribution<>(createDistribution(singletons, probabilities));
69 }
70
71 /**
72 * Create a discrete integer-valued distribution from the input data. Values are assigned
73 * mass based on their frequency. For example, [0,1,1,2] as input creates a distribution
74 * with values 0, 1 and 2 having probability masses 0.25, 0.5 and 0.25 respectively,
75 *
76 * @param data input dataset
77 */
78 public EnumeratedIntegerDistribution(final int[] data) {
79 final Map<Integer, Integer> dataMap = new HashMap<>();
80 for (int value : data) {
81 Integer count = dataMap.get(value);
82 if (count == null) {
83 count = 0;
84 }
85 dataMap.put(value, ++count);
86 }
87 final int massPoints = dataMap.size();
88 final double denom = data.length;
89 final int[] values = new int[massPoints];
90 final double[] probabilities = new double[massPoints];
91 int index = 0;
92 for (Entry<Integer, Integer> entry : dataMap.entrySet()) {
93 values[index] = entry.getKey();
94 probabilities[index] = entry.getValue() / denom;
95 index++;
96 }
97 innerDistribution =
98 new EnumeratedDistribution<>(createDistribution(values, probabilities));
99 }
100
101 /**
102 * Create the list of Pairs representing the distribution from singletons and probabilities.
103 *
104 * @param singletons values
105 * @param probabilities probabilities
106 * @return list of value/probability pairs
107 * @throws MathIllegalArgumentException if probabilities contains negative, infinite or NaN values or only 0's
108 */
109 private static List<Pair<Integer, Double>> createDistribution(int[] singletons,
110 double[] probabilities) {
111 MathUtils.checkDimension(singletons.length, probabilities.length);
112 final List<Pair<Integer, Double>> samples = new ArrayList<>(singletons.length);
113
114 final double[] normalizedProbabilities = EnumeratedDistribution.checkAndNormalize(probabilities);
115 for (int i = 0; i < singletons.length; i++) {
116 samples.add(new Pair<>(singletons[i], normalizedProbabilities[i]));
117 }
118 return samples;
119 }
120
121 /**
122 * {@inheritDoc}
123 */
124 @Override
125 public double probability(final int x) {
126 return innerDistribution.probability(x);
127 }
128
129 /**
130 * {@inheritDoc}
131 */
132 @Override
133 public double cumulativeProbability(final int x) {
134 double probability = 0;
135
136 for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) {
137 if (sample.getKey() <= x) {
138 probability += sample.getValue();
139 }
140 }
141
142 return probability;
143 }
144
145 /**
146 * {@inheritDoc}
147 *
148 * @return {@code sum(singletons[i] * probabilities[i])}
149 */
150 @Override
151 public double getNumericalMean() {
152 double mean = 0;
153
154 for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) {
155 mean += sample.getValue() * sample.getKey();
156 }
157
158 return mean;
159 }
160
161 /**
162 * {@inheritDoc}
163 *
164 * @return {@code sum((singletons[i] - mean) ^ 2 * probabilities[i])}
165 */
166 @Override
167 public double getNumericalVariance() {
168 double mean = 0;
169 double meanOfSquares = 0;
170
171 for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) {
172 mean += sample.getValue() * sample.getKey();
173 meanOfSquares += sample.getValue() * sample.getKey() * sample.getKey();
174 }
175
176 return meanOfSquares - mean * mean;
177 }
178
179 /**
180 * {@inheritDoc}
181 *
182 * Returns the lowest value with non-zero probability.
183 *
184 * @return the lowest value with non-zero probability.
185 */
186 @Override
187 public int getSupportLowerBound() {
188 int min = Integer.MAX_VALUE;
189 for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) {
190 if (sample.getKey() < min && sample.getValue() > 0) {
191 min = sample.getKey();
192 }
193 }
194
195 return min;
196 }
197
198 /**
199 * {@inheritDoc}
200 *
201 * Returns the highest value with non-zero probability.
202 *
203 * @return the highest value with non-zero probability.
204 */
205 @Override
206 public int getSupportUpperBound() {
207 int max = Integer.MIN_VALUE;
208 for (final Pair<Integer, Double> sample : innerDistribution.getPmf()) {
209 if (sample.getKey() > max && sample.getValue() > 0) {
210 max = sample.getKey();
211 }
212 }
213
214 return max;
215 }
216
217 /**
218 * {@inheritDoc}
219 *
220 * The support of this distribution is connected.
221 *
222 * @return {@code true}
223 */
224 @Override
225 public boolean isSupportConnected() {
226 return true;
227 }
228
229 /**
230 * Return the probability mass function as a list of (value, probability) pairs.
231 *
232 * @return the probability mass function.
233 */
234 public List<Pair<Integer, Double>> getPmf() {
235 return innerDistribution.getPmf();
236 }
237
238 }