HypergeometricDistribution.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      https://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. /*
  18.  * This is not the original file distributed by the Apache Software Foundation
  19.  * It has been modified by the Hipparchus project
  20.  */

  21. package org.hipparchus.distribution.discrete;

  22. import org.hipparchus.exception.LocalizedCoreFormats;
  23. import org.hipparchus.exception.MathIllegalArgumentException;
  24. import org.hipparchus.util.FastMath;

  25. /**
  26.  * Implementation of the hypergeometric distribution.
  27.  *
  28.  * @see <a href="http://en.wikipedia.org/wiki/Hypergeometric_distribution">Hypergeometric distribution (Wikipedia)</a>
  29.  * @see <a href="http://mathworld.wolfram.com/HypergeometricDistribution.html">Hypergeometric distribution (MathWorld)</a>
  30.  */
  31. public class HypergeometricDistribution extends AbstractIntegerDistribution {
  32.     /** Serializable version identifier. */
  33.     private static final long serialVersionUID = 20160320L;
  34.     /** The number of successes in the population. */
  35.     private final int numberOfSuccesses;
  36.     /** The population size. */
  37.     private final int populationSize;
  38.     /** The sample size. */
  39.     private final int sampleSize;
  40.     /** Cached numerical variance */
  41.     private final double numericalVariance;

  42.     /**
  43.      * Construct a new hypergeometric distribution with the specified population
  44.      * size, number of successes in the population, and sample size.
  45.      *
  46.      * @param populationSize Population size.
  47.      * @param numberOfSuccesses Number of successes in the population.
  48.      * @param sampleSize Sample size.
  49.      * @throws MathIllegalArgumentException if {@code numberOfSuccesses < 0}.
  50.      * @throws MathIllegalArgumentException if {@code populationSize <= 0}.
  51.      * @throws MathIllegalArgumentException if {@code numberOfSuccesses > populationSize},
  52.      * or {@code sampleSize > populationSize}.
  53.      */
  54.     public HypergeometricDistribution(int populationSize, int numberOfSuccesses, int sampleSize)
  55.         throws MathIllegalArgumentException {
  56.         if (populationSize <= 0) {
  57.             throw new MathIllegalArgumentException(LocalizedCoreFormats.POPULATION_SIZE,
  58.                                                    populationSize);
  59.         }
  60.         if (numberOfSuccesses < 0) {
  61.             throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_OF_SUCCESSES,
  62.                                                    numberOfSuccesses);
  63.         }
  64.         if (sampleSize < 0) {
  65.             throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_OF_SAMPLES,
  66.                                                    sampleSize);
  67.         }

  68.         if (numberOfSuccesses > populationSize) {
  69.             throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_OF_SUCCESS_LARGER_THAN_POPULATION_SIZE,
  70.                                                    numberOfSuccesses, populationSize, true);
  71.         }
  72.         if (sampleSize > populationSize) {
  73.             throw new MathIllegalArgumentException(LocalizedCoreFormats.SAMPLE_SIZE_LARGER_THAN_POPULATION_SIZE,
  74.                                                    sampleSize, populationSize, true);
  75.         }

  76.         this.numberOfSuccesses = numberOfSuccesses;
  77.         this.populationSize = populationSize;
  78.         this.sampleSize = sampleSize;
  79.         this.numericalVariance = calculateNumericalVariance();
  80.     }

  81.     /** {@inheritDoc} */
  82.     @Override
  83.     public double cumulativeProbability(int x) {
  84.         double ret;

  85.         int[] domain = getDomain(populationSize, numberOfSuccesses, sampleSize);
  86.         if (x < domain[0]) {
  87.             ret = 0.0;
  88.         } else if (x >= domain[1]) {
  89.             ret = 1.0;
  90.         } else {
  91.             ret = innerCumulativeProbability(domain[0], x, 1);
  92.         }

  93.         return ret;
  94.     }

  95.     /**
  96.      * Return the domain for the given hypergeometric distribution parameters.
  97.      *
  98.      * @param n Population size.
  99.      * @param m Number of successes in the population.
  100.      * @param k Sample size.
  101.      * @return a two element array containing the lower and upper bounds of the
  102.      * hypergeometric distribution.
  103.      */
  104.     private int[] getDomain(int n, int m, int k) {
  105.         return new int[] { getLowerDomain(n, m, k), getUpperDomain(m, k) };
  106.     }

  107.     /**
  108.      * Return the lowest domain value for the given hypergeometric distribution
  109.      * parameters.
  110.      *
  111.      * @param n Population size.
  112.      * @param m Number of successes in the population.
  113.      * @param k Sample size.
  114.      * @return the lowest domain value of the hypergeometric distribution.
  115.      */
  116.     private int getLowerDomain(int n, int m, int k) {
  117.         return FastMath.max(0, m - (n - k));
  118.     }

  119.     /**
  120.      * Access the number of successes.
  121.      *
  122.      * @return the number of successes.
  123.      */
  124.     public int getNumberOfSuccesses() {
  125.         return numberOfSuccesses;
  126.     }

  127.     /**
  128.      * Access the population size.
  129.      *
  130.      * @return the population size.
  131.      */
  132.     public int getPopulationSize() {
  133.         return populationSize;
  134.     }

  135.     /**
  136.      * Access the sample size.
  137.      *
  138.      * @return the sample size.
  139.      */
  140.     public int getSampleSize() {
  141.         return sampleSize;
  142.     }

  143.     /**
  144.      * Return the highest domain value for the given hypergeometric distribution
  145.      * parameters.
  146.      *
  147.      * @param m Number of successes in the population.
  148.      * @param k Sample size.
  149.      * @return the highest domain value of the hypergeometric distribution.
  150.      */
  151.     private int getUpperDomain(int m, int k) {
  152.         return FastMath.min(k, m);
  153.     }

  154.     /** {@inheritDoc} */
  155.     @Override
  156.     public double probability(int x) {
  157.         final double logProbability = logProbability(x);
  158.         return logProbability == Double.NEGATIVE_INFINITY ? 0 : FastMath.exp(logProbability);
  159.     }

  160.     /** {@inheritDoc} */
  161.     @Override
  162.     public double logProbability(int x) {
  163.         double ret;

  164.         int[] domain = getDomain(populationSize, numberOfSuccesses, sampleSize);
  165.         if (x < domain[0] || x > domain[1]) {
  166.             ret = Double.NEGATIVE_INFINITY;
  167.         } else {
  168.             double p = ((double) sampleSize) / populationSize;
  169.             double q = ((double) (populationSize - sampleSize)) / populationSize;
  170.             double p1 = SaddlePointExpansion.logBinomialProbability(x, numberOfSuccesses, p, q);
  171.             double p2 = SaddlePointExpansion.logBinomialProbability(sampleSize - x, populationSize - numberOfSuccesses, p, q);
  172.             double p3 = SaddlePointExpansion.logBinomialProbability(sampleSize, populationSize, p, q);
  173.             ret = p1 + p2 - p3;
  174.         }

  175.         return ret;
  176.     }

  177.     /**
  178.      * For this distribution, {@code X}, this method returns {@code P(X >= x)}.
  179.      *
  180.      * @param x Value at which the CDF is evaluated.
  181.      * @return the upper tail CDF for this distribution.
  182.      */
  183.     public double upperCumulativeProbability(int x) {
  184.         double ret;

  185.         final int[] domain = getDomain(populationSize, numberOfSuccesses, sampleSize);
  186.         if (x <= domain[0]) {
  187.             ret = 1.0;
  188.         } else if (x > domain[1]) {
  189.             ret = 0.0;
  190.         } else {
  191.             ret = innerCumulativeProbability(domain[1], x, -1);
  192.         }

  193.         return ret;
  194.     }

  195.     /**
  196.      * For this distribution, {@code X}, this method returns
  197.      * {@code P(x0 <= X <= x1)}.
  198.      * This probability is computed by summing the point probabilities for the
  199.      * values {@code x0, x0 + 1, x0 + 2, ..., x1}, in the order directed by
  200.      * {@code dx}.
  201.      *
  202.      * @param x0 Inclusive lower bound.
  203.      * @param x1 Inclusive upper bound.
  204.      * @param dx Direction of summation (1 indicates summing from x0 to x1, and
  205.      * 0 indicates summing from x1 to x0).
  206.      * @return {@code P(x0 <= X <= x1)}.
  207.      */
  208.     private double innerCumulativeProbability(int x0, int x1, int dx) {
  209.         double ret = probability(x0);
  210.         while (x0 != x1) {
  211.             x0 += dx;
  212.             ret += probability(x0);
  213.         }
  214.         return ret;
  215.     }

  216.     /**
  217.      * {@inheritDoc}
  218.      *
  219.      * For population size {@code N}, number of successes {@code m}, and sample
  220.      * size {@code n}, the mean is {@code n * m / N}.
  221.      */
  222.     @Override
  223.     public double getNumericalMean() {
  224.         return getSampleSize() * (getNumberOfSuccesses() / (double) getPopulationSize());
  225.     }

  226.     /**
  227.      * {@inheritDoc}
  228.      *
  229.      * For population size {@code N}, number of successes {@code m}, and sample
  230.      * size {@code n}, the variance is
  231.      * {@code [n * m * (N - n) * (N - m)] / [N^2 * (N - 1)]}.
  232.      */
  233.     @Override
  234.     public double getNumericalVariance() {
  235.         return numericalVariance;
  236.     }

  237.     /**
  238.      * Calculate the numerical variance.
  239.      *
  240.      * @return the variance of this distribution
  241.      */
  242.     private double calculateNumericalVariance() {
  243.         final double N = getPopulationSize();
  244.         final double m = getNumberOfSuccesses();
  245.         final double n = getSampleSize();
  246.         return (n * m * (N - n) * (N - m)) / (N * N * (N - 1));
  247.     }

  248.     /**
  249.      * {@inheritDoc}
  250.      *
  251.      * For population size {@code N}, number of successes {@code m}, and sample
  252.      * size {@code n}, the lower bound of the support is
  253.      * {@code max(0, n + m - N)}.
  254.      *
  255.      * @return lower bound of the support
  256.      */
  257.     @Override
  258.     public int getSupportLowerBound() {
  259.         return FastMath.max(0,
  260.                             getSampleSize() + getNumberOfSuccesses() - getPopulationSize());
  261.     }

  262.     /**
  263.      * {@inheritDoc}
  264.      *
  265.      * For number of successes {@code m} and sample size {@code n}, the upper
  266.      * bound of the support is {@code min(m, n)}.
  267.      *
  268.      * @return upper bound of the support
  269.      */
  270.     @Override
  271.     public int getSupportUpperBound() {
  272.         return FastMath.min(getNumberOfSuccesses(), getSampleSize());
  273.     }

  274.     /**
  275.      * {@inheritDoc}
  276.      *
  277.      * The support of this distribution is connected.
  278.      *
  279.      * @return {@code true}
  280.      */
  281.     @Override
  282.     public boolean isSupportConnected() {
  283.         return true;
  284.     }
  285. }