1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * https://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 /*
19 * This is not the original file distributed by the Apache Software Foundation
20 * It has been modified by the Hipparchus project
21 */
22 package org.hipparchus.distribution;
23
24 import java.io.Serializable;
25 import java.util.ArrayList;
26 import java.util.List;
27
28 import org.hipparchus.exception.LocalizedCoreFormats;
29 import org.hipparchus.exception.MathIllegalArgumentException;
30 import org.hipparchus.util.Pair;
31 import org.hipparchus.util.Precision;
32
33 /**
34 * A generic implementation of a
35 * <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution">
36 * discrete probability distribution (Wikipedia)</a> over a finite sample space,
37 * based on an enumerated list of <value, probability> pairs.
38 * <p>
39 * Input probabilities must all be non-negative, but zero values are allowed and
40 * their sum does not have to equal one. Constructors will normalize input
41 * probabilities to make them sum to one.
42 * <p>
43 * The list of <value, probability> pairs does not, strictly speaking, have
44 * to be a function and it can contain null values. The pmf created by the constructor
45 * will combine probabilities of equal values and will treat null values as equal.
46 * <p>
47 * For example, if the list of pairs <"dog", 0.2>, <null, 0.1>,
48 * <"pig", 0.2>, <"dog", 0.1>, <null, 0.4> is provided to the
49 * constructor, the resulting pmf will assign mass of 0.5 to null, 0.3 to "dog"
50 * and 0.2 to null.
51 *
52 * @param <T> type of the elements in the sample space.
53 */
54 public class EnumeratedDistribution<T> implements Serializable {
55
56 /** Serializable UID. */
57 private static final long serialVersionUID = 20123308L;
58
59 /**
60 * List of random variable values.
61 */
62 private final List<T> singletons;
63
64 /**
65 * Probabilities of respective random variable values. For i = 0, ..., singletons.size() - 1,
66 * probability[i] is the probability that a random variable following this distribution takes
67 * the value singletons[i].
68 */
69 private final double[] probabilities;
70
71 /**
72 * Create an enumerated distribution using the given probability mass function
73 * enumeration.
74 *
75 * @param pmf probability mass function enumerated as a list of <T, probability>
76 * pairs.
77 * @throws MathIllegalArgumentException of weights includes negative, NaN or infinite values or only 0's
78 */
79 public EnumeratedDistribution(final List<Pair<T, Double>> pmf)
80 throws MathIllegalArgumentException {
81
82 singletons = new ArrayList<>(pmf.size());
83 final double[] probs = new double[pmf.size()];
84
85 for (int i = 0; i < pmf.size(); i++) {
86 final Pair<T, Double> sample = pmf.get(i);
87 singletons.add(sample.getKey());
88 final double p = sample.getValue();
89 probs[i] = p;
90 }
91
92 probabilities = checkAndNormalize(probs);
93
94 }
95
96 /**
97 * For a random variable {@code X} whose values are distributed according to
98 * this distribution, this method returns {@code P(X = x)}. In other words,
99 * this method represents the probability mass function (PMF) for the
100 * distribution.
101 * <p>
102 * Note that if {@code x1} and {@code x2} satisfy {@code x1.equals(x2)},
103 * or both are null, then {@code probability(x1) = probability(x2)}.
104 *
105 * @param x the point at which the PMF is evaluated
106 * @return the value of the probability mass function at {@code x}
107 */
108 public double probability(final T x) {
109 double probability = 0;
110
111 for (int i = 0; i < probabilities.length; i++) {
112 if ((x == null && singletons.get(i) == null) ||
113 (x != null && x.equals(singletons.get(i)))) {
114 probability += probabilities[i];
115 }
116 }
117
118 return probability;
119 }
120
121 /**
122 * Return the probability mass function as a list of (value, probability) pairs.
123 * <p>
124 * Note that if duplicate and / or null values were provided to the constructor
125 * when creating this EnumeratedDistribution, the returned list will contain these
126 * values. If duplicates values exist, what is returned will not represent
127 * a pmf (i.e., it is up to the caller to consolidate duplicate mass points).
128 *
129 * @return the probability mass function.
130 */
131 public List<Pair<T, Double>> getPmf() {
132 final List<Pair<T, Double>> samples = new ArrayList<>(probabilities.length);
133
134 for (int i = 0; i < probabilities.length; i++) {
135 samples.add(new Pair<>(singletons.get(i), probabilities[i]));
136 }
137
138 return samples;
139 }
140
141 /**
142 * Checks to make sure that weights is neither null nor empty and contains only non-negative, finite,
143 * non-NaN values and if necessary normalizes it to sum to 1.
144 *
145 * @param weights input array to be used as the basis for the values of a PMF
146 * @return a possibly rescaled copy of the array that sums to 1 and contains only valid probability values
147 * @throws MathIllegalArgumentException of weights is null or empty or includes negative, NaN or
148 * infinite values or only 0's
149 */
150 public static double[] checkAndNormalize(double[] weights) {
151 if (weights == null || weights.length == 0) {
152 throw new MathIllegalArgumentException(LocalizedCoreFormats.ARRAY_ZERO_LENGTH_OR_NULL_NOT_ALLOWED);
153 }
154 final int len = weights.length;
155 double sumWt = 0;
156 boolean posWt = false;
157 for (int i = 0; i < len; i++) {
158 if (weights[i] < 0) {
159 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL,
160 weights[i], 0);
161 }
162 if (weights[i] > 0) {
163 posWt = true;
164 }
165 if (Double.isNaN(weights[i])) {
166 throw new MathIllegalArgumentException(LocalizedCoreFormats.NAN_ELEMENT_AT_INDEX, i);
167 }
168 if (Double.isInfinite(weights[i])) {
169 throw new MathIllegalArgumentException(LocalizedCoreFormats.INFINITE_ARRAY_ELEMENT,
170 weights[i], i);
171 }
172 sumWt += weights[i];
173 }
174 if (!posWt) {
175 throw new MathIllegalArgumentException(LocalizedCoreFormats.WEIGHT_AT_LEAST_ONE_NON_ZERO);
176 }
177 double[] normWt;
178 if (Precision.equals(sumWt, 1d, 10)) { // allow small error (10 ulps)
179 normWt = weights;
180 } else {
181 normWt = new double[len];
182 for (int i = 0; i < len; i++) {
183 normWt[i] = weights[i] / sumWt;
184 }
185 }
186 return normWt;
187 }
188
189 }