EnumeratedDistribution.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* This is not the original file distributed by the Apache Software Foundation
* It has been modified by the Hipparchus project
*/
package org.hipparchus.distribution;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import org.hipparchus.exception.LocalizedCoreFormats;
import org.hipparchus.exception.MathIllegalArgumentException;
import org.hipparchus.util.Pair;
import org.hipparchus.util.Precision;
/**
* A generic implementation of a
* <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution">
* discrete probability distribution (Wikipedia)</a> over a finite sample space,
* based on an enumerated list of &lt;value, probability&gt; pairs.
* <p>
* Input probabilities must all be non-negative, but zero values are allowed and
* their sum does not have to equal one. Constructors will normalize input
* probabilities to make them sum to one.
* <p>
* The list of &lt;value, probability&gt; pairs does not, strictly speaking, have
* to be a function and it can contain null values. The pmf created by the constructor
* will combine probabilities of equal values and will treat null values as equal.
* <p>
* For example, if the list of pairs &lt;"dog", 0.2&gt;, &lt;null, 0.1&gt;,
* &lt;"pig", 0.2&gt;, &lt;"dog", 0.1&gt;, &lt;null, 0.4&gt; is provided to the
* constructor, the resulting pmf will assign mass of 0.5 to null, 0.3 to "dog"
* and 0.2 to "pig".
*
* @param <T> type of the elements in the sample space.
*/
public class EnumeratedDistribution<T> implements Serializable {

    /** Serializable UID. */
    private static final long serialVersionUID = 20123308L;

    /**
     * List of random variable values.
     */
    private final List<T> singletons;

    /**
     * Probabilities of respective random variable values. For i = 0, ..., singletons.size() - 1,
     * probability[i] is the probability that a random variable following this distribution takes
     * the value singletons[i].
     */
    private final double[] probabilities;

    /**
     * Create an enumerated distribution using the given probability mass function
     * enumeration.
     * <p>
     * The input list is traversed once through its iterator, so non random-access
     * {@link List} implementations are processed in linear time.
     *
     * @param pmf probability mass function enumerated as a list of
     * {@code <T, probability>} pairs; probabilities are normalized to sum to one
     * @throws MathIllegalArgumentException if {@code pmf} is empty, or if the weights
     * include negative, NaN or infinite values or only 0's
     * @throws NullPointerException if {@code pmf}, one of its pairs, or one of the
     * pair probabilities is null
     */
    public EnumeratedDistribution(final List<Pair<T, Double>> pmf)
        throws MathIllegalArgumentException {
        singletons = new ArrayList<>(pmf.size());
        final double[] probs = new double[pmf.size()];
        int i = 0;
        // iterate instead of calling pmf.get(i): get(i) is O(n) on e.g. LinkedList
        for (final Pair<T, Double> sample : pmf) {
            singletons.add(sample.getKey());
            // auto-unboxing: a null probability raises NullPointerException here
            probs[i++] = sample.getValue();
        }
        probabilities = checkAndNormalize(probs);
    }

    /**
     * For a random variable {@code X} whose values are distributed according to
     * this distribution, this method returns {@code P(X = x)}. In other words,
     * this method represents the probability mass function (PMF) for the
     * distribution.
     * <p>
     * Note that if {@code x1} and {@code x2} satisfy {@code x1.equals(x2)},
     * or both are null, then {@code probability(x1) = probability(x2)}.
     * The masses of all stored entries equal to {@code x} are summed.
     *
     * @param x the point at which the PMF is evaluated
     * @return the value of the probability mass function at {@code x}
     */
    public double probability(final T x) {
        double probability = 0;
        for (int i = 0; i < probabilities.length; i++) {
            final T singleton = singletons.get(i);
            if ((x == null && singleton == null) ||
                (x != null && x.equals(singleton))) {
                probability += probabilities[i];
            }
        }
        return probability;
    }

    /**
     * Return the probability mass function as a list of (value, probability) pairs.
     * <p>
     * Note that if duplicate and / or null values were provided to the constructor
     * when creating this EnumeratedDistribution, the returned list will contain these
     * values. If duplicate values exist, what is returned will not represent
     * a pmf (i.e., it is up to the caller to consolidate duplicate mass points).
     * <p>
     * A new list is built on each call; modifying it does not affect this distribution.
     *
     * @return the probability mass function.
     */
    public List<Pair<T, Double>> getPmf() {
        final List<Pair<T, Double>> samples = new ArrayList<>(probabilities.length);
        for (int i = 0; i < probabilities.length; i++) {
            samples.add(new Pair<>(singletons.get(i), probabilities[i]));
        }
        return samples;
    }

    /**
     * Checks to make sure that weights is neither null nor empty and contains only non-negative, finite,
     * non-NaN values and if necessary normalizes it to sum to 1.
     * <p>
     * The returned array is always a new array, never the input array itself, so callers
     * may keep it without being affected by later changes to {@code weights}.
     *
     * @param weights input array to be used as the basis for the values of a PMF
     * @return a possibly rescaled copy of the array that sums to 1 and contains only valid probability values
     * @throws MathIllegalArgumentException if weights is null or empty or includes negative, NaN or
     * infinite values or only 0's
     */
    public static double[] checkAndNormalize(double[] weights) {
        if (weights == null || weights.length == 0) {
            throw new MathIllegalArgumentException(LocalizedCoreFormats.ARRAY_ZERO_LENGTH_OR_NULL_NOT_ALLOWED);
        }
        final int len = weights.length;
        double sumWt = 0;
        double maxWt = 0;
        boolean posWt = false;
        for (int i = 0; i < len; i++) {
            if (weights[i] < 0) {
                throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL,
                                                       weights[i], 0);
            }
            if (weights[i] > 0) {
                posWt = true;
            }
            if (Double.isNaN(weights[i])) {
                throw new MathIllegalArgumentException(LocalizedCoreFormats.NAN_ELEMENT_AT_INDEX, i);
            }
            if (Double.isInfinite(weights[i])) {
                throw new MathIllegalArgumentException(LocalizedCoreFormats.INFINITE_ARRAY_ELEMENT,
                                                       weights[i], i);
            }
            // only finite, non-negative values reach this point
            sumWt += weights[i];
            maxWt = Math.max(maxWt, weights[i]);
        }
        if (!posWt) {
            throw new MathIllegalArgumentException(LocalizedCoreFormats.WEIGHT_AT_LEAST_ONE_NON_ZERO);
        }

        if (Precision.equals(sumWt, 1d, 10)) { // allow small error (10 ulps)
            // already normalized: keep the values, but never alias the caller's array
            return weights.clone();
        }

        final double[] normWt = new double[len];
        if (Double.isInfinite(sumWt)) {
            // every weight is finite but their sum overflowed; dividing by an infinite
            // sum would map every weight to 0, so rescale by the largest weight first
            // (the rescaled sum then lies in [1, len] and cannot overflow)
            double scaledSum = 0;
            for (int i = 0; i < len; i++) {
                normWt[i] = weights[i] / maxWt;
                scaledSum += normWt[i];
            }
            for (int i = 0; i < len; i++) {
                normWt[i] /= scaledSum;
            }
        } else {
            for (int i = 0; i < len; i++) {
                normWt[i] = weights[i] / sumWt;
            }
        }
        return normWt;
    }
}