EnumeratedDistribution.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      https://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. /*
  18.  * This is not the original file distributed by the Apache Software Foundation
  19.  * It has been modified by the Hipparchus project
  20.  */
  21. package org.hipparchus.distribution;

  22. import java.io.Serializable;
  23. import java.util.ArrayList;
  24. import java.util.List;

  25. import org.hipparchus.exception.LocalizedCoreFormats;
  26. import org.hipparchus.exception.MathIllegalArgumentException;
  27. import org.hipparchus.util.Pair;
  28. import org.hipparchus.util.Precision;

  29. /**
  30.  * A generic implementation of a
  31.  * <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution">
  32.  * discrete probability distribution (Wikipedia)</a> over a finite sample space,
  33.  * based on an enumerated list of &lt;value, probability&gt; pairs.
  34.  * <p>
  35.  * Input probabilities must all be non-negative, but zero values are allowed and
  36.  * their sum does not have to equal one. Constructors will normalize input
  37.  * probabilities to make them sum to one.
  38.  * <p>
  39.  * The list of &lt;value, probability&gt; pairs does not, strictly speaking, have
  40.  * to be a function and it can contain null values.  The pmf created by the constructor
  41.  * will combine probabilities of equal values and will treat null values as equal.
  42.  * <p>
  43.  * For example, if the list of pairs &lt;"dog", 0.2&gt;, &lt;null, 0.1&gt;,
  44.  * &lt;"pig", 0.2&gt;, &lt;"dog", 0.1&gt;, &lt;null, 0.4&gt; is provided to the
  45.  * constructor, the resulting pmf will assign mass of 0.5 to null, 0.3 to "dog"
  46.  * and 0.2 to null.
  47.  *
  48.  * @param <T> type of the elements in the sample space.
  49.  */
  50. public class EnumeratedDistribution<T> implements Serializable {

  51.     /** Serializable UID. */
  52.     private static final long serialVersionUID = 20123308L;

  53.     /**
  54.      * List of random variable values.
  55.      */
  56.     private final List<T> singletons;

  57.     /**
  58.      * Probabilities of respective random variable values. For i = 0, ..., singletons.size() - 1,
  59.      * probability[i] is the probability that a random variable following this distribution takes
  60.      * the value singletons[i].
  61.      */
  62.     private final double[] probabilities;

  63.     /**
  64.      * Create an enumerated distribution using the given probability mass function
  65.      * enumeration.
  66.      *
  67.      * @param pmf probability mass function enumerated as a list of &lt;T, probability&gt;
  68.      * pairs.
  69.      * @throws MathIllegalArgumentException of weights includes negative, NaN or infinite values or only 0's
  70.      */
  71.     public EnumeratedDistribution(final List<Pair<T, Double>> pmf)
  72.         throws MathIllegalArgumentException {

  73.         singletons = new ArrayList<>(pmf.size());
  74.         final double[] probs = new double[pmf.size()];

  75.         for (int i = 0; i < pmf.size(); i++) {
  76.             final Pair<T, Double> sample = pmf.get(i);
  77.             singletons.add(sample.getKey());
  78.             final double p = sample.getValue();
  79.             probs[i] = p;
  80.         }

  81.         probabilities = checkAndNormalize(probs);

  82.     }

  83.     /**
  84.      * For a random variable {@code X} whose values are distributed according to
  85.      * this distribution, this method returns {@code P(X = x)}. In other words,
  86.      * this method represents the probability mass function (PMF) for the
  87.      * distribution.
  88.      * <p>
  89.      * Note that if {@code x1} and {@code x2} satisfy {@code x1.equals(x2)},
  90.      * or both are null, then {@code probability(x1) = probability(x2)}.
  91.      *
  92.      * @param x the point at which the PMF is evaluated
  93.      * @return the value of the probability mass function at {@code x}
  94.      */
  95.     public double probability(final T x) {
  96.         double probability = 0;

  97.         for (int i = 0; i < probabilities.length; i++) {
  98.             if ((x == null && singletons.get(i) == null) ||
  99.                 (x != null && x.equals(singletons.get(i)))) {
  100.                 probability += probabilities[i];
  101.             }
  102.         }

  103.         return probability;
  104.     }

  105.     /**
  106.      * Return the probability mass function as a list of (value, probability) pairs.
  107.      * <p>
  108.      * Note that if duplicate and / or null values were provided to the constructor
  109.      * when creating this EnumeratedDistribution, the returned list will contain these
  110.      * values.  If duplicates values exist, what is returned will not represent
  111.      * a pmf (i.e., it is up to the caller to consolidate duplicate mass points).
  112.      *
  113.      * @return the probability mass function.
  114.      */
  115.     public List<Pair<T, Double>> getPmf() {
  116.         final List<Pair<T, Double>> samples = new ArrayList<>(probabilities.length);

  117.         for (int i = 0; i < probabilities.length; i++) {
  118.             samples.add(new Pair<>(singletons.get(i), probabilities[i]));
  119.         }

  120.         return samples;
  121.     }

  122.     /**
  123.      * Checks to make sure that weights is neither null nor empty and contains only non-negative, finite,
  124.      * non-NaN values and if necessary normalizes it to sum to 1.
  125.      *
  126.      * @param weights input array to be used as the basis for the values of a PMF
  127.      * @return a possibly rescaled copy of the array that sums to 1 and contains only valid probability values
  128.      * @throws MathIllegalArgumentException of weights is null or empty or includes negative, NaN or
  129.      *         infinite values or only 0's
  130.      */
  131.     public static double[] checkAndNormalize(double[] weights) {
  132.         if (weights == null || weights.length == 0) {
  133.             throw new MathIllegalArgumentException(LocalizedCoreFormats.ARRAY_ZERO_LENGTH_OR_NULL_NOT_ALLOWED);
  134.         }
  135.         final int len = weights.length;
  136.         double sumWt = 0;
  137.         boolean posWt = false;
  138.         for (int i = 0; i < len; i++) {
  139.             if (weights[i] < 0) {
  140.                 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL,
  141.                                                        weights[i], 0);
  142.             }
  143.             if (weights[i] > 0) {
  144.                 posWt = true;
  145.             }
  146.             if (Double.isNaN(weights[i])) {
  147.                 throw new MathIllegalArgumentException(LocalizedCoreFormats.NAN_ELEMENT_AT_INDEX, i);
  148.             }
  149.             if (Double.isInfinite(weights[i])) {
  150.                 throw new MathIllegalArgumentException(LocalizedCoreFormats.INFINITE_ARRAY_ELEMENT,
  151.                                                        weights[i], i);
  152.             }
  153.             sumWt += weights[i];
  154.         }
  155.         if (!posWt) {
  156.             throw new MathIllegalArgumentException(LocalizedCoreFormats.WEIGHT_AT_LEAST_ONE_NON_ZERO);
  157.         }
  158.         double[] normWt;
  159.         if (Precision.equals(sumWt, 1d, 10)) { // allow small error (10 ulps)
  160.             normWt = weights;
  161.         } else {
  162.             normWt = new double[len];
  163.             for (int i = 0; i < len; i++) {
  164.                 normWt[i] = weights[i] / sumWt;
  165.             }
  166.         }
  167.         return normWt;
  168.     }

  169. }