View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  package org.hipparchus.distribution;
23  
24  import java.io.Serializable;
25  import java.util.ArrayList;
26  import java.util.List;
27  
28  import org.hipparchus.exception.LocalizedCoreFormats;
29  import org.hipparchus.exception.MathIllegalArgumentException;
30  import org.hipparchus.util.Pair;
31  import org.hipparchus.util.Precision;
32  
33  /**
34   * A generic implementation of a
35   * <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution">
36   * discrete probability distribution (Wikipedia)</a> over a finite sample space,
37   * based on an enumerated list of &lt;value, probability&gt; pairs.
38   * <p>
39   * Input probabilities must all be non-negative, but zero values are allowed and
40   * their sum does not have to equal one. Constructors will normalize input
41   * probabilities to make them sum to one.
42   * <p>
43   * The list of &lt;value, probability&gt; pairs does not, strictly speaking, have
44   * to be a function and it can contain null values.  The pmf created by the constructor
45   * will combine probabilities of equal values and will treat null values as equal.
46   * <p>
47   * For example, if the list of pairs &lt;"dog", 0.2&gt;, &lt;null, 0.1&gt;,
48   * &lt;"pig", 0.2&gt;, &lt;"dog", 0.1&gt;, &lt;null, 0.4&gt; is provided to the
49   * constructor, the resulting pmf will assign mass of 0.5 to null, 0.3 to "dog"
50   * and 0.2 to null.
51   *
52   * @param <T> type of the elements in the sample space.
53   */
54  public class EnumeratedDistribution<T> implements Serializable {
55  
56      /** Serializable UID. */
57      private static final long serialVersionUID = 20123308L;
58  
59      /**
60       * List of random variable values.
61       */
62      private final List<T> singletons;
63  
64      /**
65       * Probabilities of respective random variable values. For i = 0, ..., singletons.size() - 1,
66       * probability[i] is the probability that a random variable following this distribution takes
67       * the value singletons[i].
68       */
69      private final double[] probabilities;
70  
71      /**
72       * Create an enumerated distribution using the given probability mass function
73       * enumeration.
74       *
75       * @param pmf probability mass function enumerated as a list of &lt;T, probability&gt;
76       * pairs.
77       * @throws MathIllegalArgumentException of weights includes negative, NaN or infinite values or only 0's
78       */
79      public EnumeratedDistribution(final List<Pair<T, Double>> pmf)
80          throws MathIllegalArgumentException {
81  
82          singletons = new ArrayList<>(pmf.size());
83          final double[] probs = new double[pmf.size()];
84  
85          for (int i = 0; i < pmf.size(); i++) {
86              final Pair<T, Double> sample = pmf.get(i);
87              singletons.add(sample.getKey());
88              final double p = sample.getValue();
89              probs[i] = p;
90          }
91  
92          probabilities = checkAndNormalize(probs);
93  
94      }
95  
96      /**
97       * For a random variable {@code X} whose values are distributed according to
98       * this distribution, this method returns {@code P(X = x)}. In other words,
99       * this method represents the probability mass function (PMF) for the
100      * distribution.
101      * <p>
102      * Note that if {@code x1} and {@code x2} satisfy {@code x1.equals(x2)},
103      * or both are null, then {@code probability(x1) = probability(x2)}.
104      *
105      * @param x the point at which the PMF is evaluated
106      * @return the value of the probability mass function at {@code x}
107      */
108     public double probability(final T x) {
109         double probability = 0;
110 
111         for (int i = 0; i < probabilities.length; i++) {
112             if ((x == null && singletons.get(i) == null) ||
113                 (x != null && x.equals(singletons.get(i)))) {
114                 probability += probabilities[i];
115             }
116         }
117 
118         return probability;
119     }
120 
121     /**
122      * Return the probability mass function as a list of (value, probability) pairs.
123      * <p>
124      * Note that if duplicate and / or null values were provided to the constructor
125      * when creating this EnumeratedDistribution, the returned list will contain these
126      * values.  If duplicates values exist, what is returned will not represent
127      * a pmf (i.e., it is up to the caller to consolidate duplicate mass points).
128      *
129      * @return the probability mass function.
130      */
131     public List<Pair<T, Double>> getPmf() {
132         final List<Pair<T, Double>> samples = new ArrayList<>(probabilities.length);
133 
134         for (int i = 0; i < probabilities.length; i++) {
135             samples.add(new Pair<>(singletons.get(i), probabilities[i]));
136         }
137 
138         return samples;
139     }
140 
141     /**
142      * Checks to make sure that weights is neither null nor empty and contains only non-negative, finite,
143      * non-NaN values and if necessary normalizes it to sum to 1.
144      *
145      * @param weights input array to be used as the basis for the values of a PMF
146      * @return a possibly rescaled copy of the array that sums to 1 and contains only valid probability values
147      * @throws MathIllegalArgumentException of weights is null or empty or includes negative, NaN or
148      *         infinite values or only 0's
149      */
150     public static double[] checkAndNormalize(double[] weights) {
151         if (weights == null || weights.length == 0) {
152             throw new MathIllegalArgumentException(LocalizedCoreFormats.ARRAY_ZERO_LENGTH_OR_NULL_NOT_ALLOWED);
153         }
154         final int len = weights.length;
155         double sumWt = 0;
156         boolean posWt = false;
157         for (int i = 0; i < len; i++) {
158             if (weights[i] < 0) {
159                 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL,
160                                                        weights[i], 0);
161             }
162             if (weights[i] > 0) {
163                 posWt = true;
164             }
165             if (Double.isNaN(weights[i])) {
166                 throw new MathIllegalArgumentException(LocalizedCoreFormats.NAN_ELEMENT_AT_INDEX, i);
167             }
168             if (Double.isInfinite(weights[i])) {
169                 throw new MathIllegalArgumentException(LocalizedCoreFormats.INFINITE_ARRAY_ELEMENT,
170                                                        weights[i], i);
171             }
172             sumWt += weights[i];
173         }
174         if (!posWt) {
175             throw new MathIllegalArgumentException(LocalizedCoreFormats.WEIGHT_AT_LEAST_ONE_NON_ZERO);
176         }
177         double[] normWt;
178         if (Precision.equals(sumWt, 1d, 10)) { // allow small error (10 ulps)
179             normWt = weights;
180         } else {
181             normWt = new double[len];
182             for (int i = 0; i < len; i++) {
183                 normWt[i] = weights[i] / sumWt;
184             }
185         }
186         return normWt;
187     }
188 
189 }