MixtureMultivariateRealDistribution.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      https://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. /*
  18.  * This is not the original file distributed by the Apache Software Foundation
  19.  * It has been modified by the Hipparchus project
  20.  */
  21. package org.hipparchus.distribution.multivariate;

  22. import java.util.ArrayList;
  23. import java.util.List;

  24. import org.hipparchus.distribution.MultivariateRealDistribution;
  25. import org.hipparchus.exception.LocalizedCoreFormats;
  26. import org.hipparchus.exception.MathIllegalArgumentException;
  27. import org.hipparchus.exception.MathRuntimeException;
  28. import org.hipparchus.random.RandomGenerator;
  29. import org.hipparchus.random.Well19937c;
  30. import org.hipparchus.util.Pair;

  31. /**
  32.  * Class for representing <a href="http://en.wikipedia.org/wiki/Mixture_model">
  33.  * mixture model</a> distributions.
  34.  *
  35.  * @param <T> Type of the mixture components.
  36.  */
  37. public class MixtureMultivariateRealDistribution<T extends MultivariateRealDistribution>
  38.     extends AbstractMultivariateRealDistribution {
  39.     /** Normalized weight of each mixture component. */
  40.     private final double[] weight;
  41.     /** Mixture components. */
  42.     private final List<T> distribution;

  43.     /**
  44.      * Creates a mixture model from a list of distributions and their
  45.      * associated weights.
  46.      * <p>
  47.      * <b>Note:</b> this constructor will implicitly create an instance of
  48.      * {@link Well19937c} as random generator to be used for sampling only (see
  49.      * {@link #sample()} and {@link #sample(int)}). In case no sampling is
  50.      * needed for the created distribution, it is advised to pass {@code null}
  51.      * as random generator via the appropriate constructors to avoid the
  52.      * additional initialisation overhead.
  53.      *
  54.      * @param components List of (weight, distribution) pairs from which to sample.
  55.      */
  56.     public MixtureMultivariateRealDistribution(List<Pair<Double, T>> components) {
  57.         this(new Well19937c(), components);
  58.     }

  59.     /**
  60.      * Creates a mixture model from a list of distributions and their
  61.      * associated weights.
  62.      *
  63.      * @param rng Random number generator.
  64.      * @param components Distributions from which to sample.
  65.      * @throws MathIllegalArgumentException if any of the weights is negative.
  66.      * @throws MathIllegalArgumentException if not all components have the same
  67.      * number of variables.
  68.      */
  69.     public MixtureMultivariateRealDistribution(RandomGenerator rng,
  70.                                                List<Pair<Double, T>> components) {
  71.         super(rng, components.get(0).getSecond().getDimension());

  72.         final int numComp = components.size();
  73.         final int dim = getDimension();
  74.         double weightSum = 0;
  75.         for (final Pair<Double, T> comp : components) {
  76.             if (comp.getSecond().getDimension() != dim) {
  77.                 throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
  78.                         comp.getSecond().getDimension(), dim);
  79.             }
  80.             if (comp.getFirst() < 0) {
  81.                 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL, comp.getFirst(), 0);
  82.             }
  83.             weightSum += comp.getFirst();
  84.         }

  85.         // Check for overflow.
  86.         if (Double.isInfinite(weightSum)) {
  87.             throw new MathRuntimeException(LocalizedCoreFormats.OVERFLOW);
  88.         }

  89.         // Store each distribution and its normalized weight.
  90.         distribution = new ArrayList<>();
  91.         weight = new double[numComp];
  92.         for (int i = 0; i < numComp; i++) {
  93.             final Pair<Double, T> comp = components.get(i);
  94.             weight[i] = comp.getFirst() / weightSum;
  95.             distribution.add(comp.getSecond());
  96.         }
  97.     }

  98.     /** {@inheritDoc} */
  99.     @Override
  100.     public double density(final double[] values) {
  101.         double p = 0;
  102.         for (int i = 0; i < weight.length; i++) {
  103.             p += weight[i] * distribution.get(i).density(values);
  104.         }
  105.         return p;
  106.     }

  107.     /** {@inheritDoc} */
  108.     @Override
  109.     public double[] sample() {
  110.         // Sampled values.
  111.         double[] vals = null;

  112.         // Determine which component to sample from.
  113.         final double randomValue = random.nextDouble();
  114.         double sum = 0;

  115.         for (int i = 0; i < weight.length; i++) {
  116.             sum += weight[i];
  117.             if (randomValue <= sum) {
  118.                 // pick model i
  119.                 vals = distribution.get(i).sample();
  120.                 break;
  121.             }
  122.         }

  123.         if (vals == null) {
  124.             // This should never happen, but it ensures we won't return a null in
  125.             // case the loop above has some floating point inequality problem on
  126.             // the final iteration.
  127.             vals = distribution.get(weight.length - 1).sample();
  128.         }

  129.         return vals;
  130.     }

  131.     /** {@inheritDoc} */
  132.     @Override
  133.     public void reseedRandomGenerator(long seed) {
  134.         // Seed needs to be propagated to underlying components
  135.         // in order to maintain consistency between runs.
  136.         super.reseedRandomGenerator(seed);

  137.         for (int i = 0; i < distribution.size(); i++) {
  138.             // Make each component's seed different in order to avoid
  139.             // using the same sequence of random numbers.
  140.             distribution.get(i).reseedRandomGenerator(i + 1 + seed);
  141.         }
  142.     }

  143.     /**
  144.      * Gets the distributions that make up the mixture model.
  145.      *
  146.      * @return the component distributions and associated weights.
  147.      */
  148.     public List<Pair<Double, T>> getComponents() {
  149.         final List<Pair<Double, T>> list = new ArrayList<>(weight.length);

  150.         for (int i = 0; i < weight.length; i++) {
  151.             list.add(new Pair<>(weight[i], distribution.get(i)));
  152.         }

  153.         return list;
  154.     }
  155. }