1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one or more 3 * contributor license agreements. See the NOTICE file distributed with 4 * this work for additional information regarding copyright ownership. 5 * The ASF licenses this file to You under the Apache License, Version 2.0 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * https://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 18 /* 19 * This is not the original file distributed by the Apache Software Foundation 20 * It has been modified by the Hipparchus project 21 */ 22 package org.hipparchus.distribution; 23 24 import java.io.Serializable; 25 import java.util.ArrayList; 26 import java.util.List; 27 28 import org.hipparchus.exception.LocalizedCoreFormats; 29 import org.hipparchus.exception.MathIllegalArgumentException; 30 import org.hipparchus.util.Pair; 31 import org.hipparchus.util.Precision; 32 33 /** 34 * A generic implementation of a 35 * <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution"> 36 * discrete probability distribution (Wikipedia)</a> over a finite sample space, 37 * based on an enumerated list of <value, probability> pairs. 38 * <p> 39 * Input probabilities must all be non-negative, but zero values are allowed and 40 * their sum does not have to equal one. Constructors will normalize input 41 * probabilities to make them sum to one. 42 * <p> 43 * The list of <value, probability> pairs does not, strictly speaking, have 44 * to be a function and it can contain null values. The pmf created by the constructor 45 * will combine probabilities of equal values and will treat null values as equal. 46 * <p> 47 * For example, if the list of pairs <"dog", 0.2>, <null, 0.1>, 48 * <"pig", 0.2>, <"dog", 0.1>, <null, 0.4> is provided to the 49 * constructor, the resulting pmf will assign mass of 0.5 to null, 0.3 to "dog" 50 * and 0.2 to null. 51 * 52 * @param <T> type of the elements in the sample space. 53 */ 54 public class EnumeratedDistribution<T> implements Serializable { 55 56 /** Serializable UID. */ 57 private static final long serialVersionUID = 20123308L; 58 59 /** 60 * List of random variable values. 61 */ 62 private final List<T> singletons; 63 64 /** 65 * Probabilities of respective random variable values. For i = 0, ..., singletons.size() - 1, 66 * probability[i] is the probability that a random variable following this distribution takes 67 * the value singletons[i]. 68 */ 69 private final double[] probabilities; 70 71 /** 72 * Create an enumerated distribution using the given probability mass function 73 * enumeration. 74 * 75 * @param pmf probability mass function enumerated as a list of <T, probability> 76 * pairs. 77 * @throws MathIllegalArgumentException of weights includes negative, NaN or infinite values or only 0's 78 */ 79 public EnumeratedDistribution(final List<Pair<T, Double>> pmf) 80 throws MathIllegalArgumentException { 81 82 singletons = new ArrayList<>(pmf.size()); 83 final double[] probs = new double[pmf.size()]; 84 85 for (int i = 0; i < pmf.size(); i++) { 86 final Pair<T, Double> sample = pmf.get(i); 87 singletons.add(sample.getKey()); 88 final double p = sample.getValue(); 89 probs[i] = p; 90 } 91 92 probabilities = checkAndNormalize(probs); 93 94 } 95 96 /** 97 * For a random variable {@code X} whose values are distributed according to 98 * this distribution, this method returns {@code P(X = x)}. In other words, 99 * this method represents the probability mass function (PMF) for the 100 * distribution. 101 * <p> 102 * Note that if {@code x1} and {@code x2} satisfy {@code x1.equals(x2)}, 103 * or both are null, then {@code probability(x1) = probability(x2)}. 104 * 105 * @param x the point at which the PMF is evaluated 106 * @return the value of the probability mass function at {@code x} 107 */ 108 public double probability(final T x) { 109 double probability = 0; 110 111 for (int i = 0; i < probabilities.length; i++) { 112 if ((x == null && singletons.get(i) == null) || 113 (x != null && x.equals(singletons.get(i)))) { 114 probability += probabilities[i]; 115 } 116 } 117 118 return probability; 119 } 120 121 /** 122 * Return the probability mass function as a list of (value, probability) pairs. 123 * <p> 124 * Note that if duplicate and / or null values were provided to the constructor 125 * when creating this EnumeratedDistribution, the returned list will contain these 126 * values. If duplicates values exist, what is returned will not represent 127 * a pmf (i.e., it is up to the caller to consolidate duplicate mass points). 128 * 129 * @return the probability mass function. 130 */ 131 public List<Pair<T, Double>> getPmf() { 132 final List<Pair<T, Double>> samples = new ArrayList<>(probabilities.length); 133 134 for (int i = 0; i < probabilities.length; i++) { 135 samples.add(new Pair<>(singletons.get(i), probabilities[i])); 136 } 137 138 return samples; 139 } 140 141 /** 142 * Checks to make sure that weights is neither null nor empty and contains only non-negative, finite, 143 * non-NaN values and if necessary normalizes it to sum to 1. 144 * 145 * @param weights input array to be used as the basis for the values of a PMF 146 * @return a possibly rescaled copy of the array that sums to 1 and contains only valid probability values 147 * @throws MathIllegalArgumentException of weights is null or empty or includes negative, NaN or 148 * infinite values or only 0's 149 */ 150 public static double[] checkAndNormalize(double[] weights) { 151 if (weights == null || weights.length == 0) { 152 throw new MathIllegalArgumentException(LocalizedCoreFormats.ARRAY_ZERO_LENGTH_OR_NULL_NOT_ALLOWED); 153 } 154 final int len = weights.length; 155 double sumWt = 0; 156 boolean posWt = false; 157 for (int i = 0; i < len; i++) { 158 if (weights[i] < 0) { 159 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL, 160 weights[i], 0); 161 } 162 if (weights[i] > 0) { 163 posWt = true; 164 } 165 if (Double.isNaN(weights[i])) { 166 throw new MathIllegalArgumentException(LocalizedCoreFormats.NAN_ELEMENT_AT_INDEX, i); 167 } 168 if (Double.isInfinite(weights[i])) { 169 throw new MathIllegalArgumentException(LocalizedCoreFormats.INFINITE_ARRAY_ELEMENT, 170 weights[i], i); 171 } 172 sumWt += weights[i]; 173 } 174 if (!posWt) { 175 throw new MathIllegalArgumentException(LocalizedCoreFormats.WEIGHT_AT_LEAST_ONE_NON_ZERO); 176 } 177 double[] normWt; 178 if (Precision.equals(sumWt, 1d, 10)) { // allow small error (10 ulps) 179 normWt = weights; 180 } else { 181 normWt = new double[len]; 182 for (int i = 0; i < len; i++) { 183 normWt[i] = weights[i] / sumWt; 184 } 185 } 186 return normWt; 187 } 188 189 }