MixtureMultivariateRealDistribution.java
- /*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- /*
- * This is not the original file distributed by the Apache Software Foundation
- * It has been modified by the Hipparchus project
- */
- package org.hipparchus.distribution.multivariate;
- import java.util.ArrayList;
- import java.util.List;
- import org.hipparchus.distribution.MultivariateRealDistribution;
- import org.hipparchus.exception.LocalizedCoreFormats;
- import org.hipparchus.exception.MathIllegalArgumentException;
- import org.hipparchus.exception.MathRuntimeException;
- import org.hipparchus.random.RandomGenerator;
- import org.hipparchus.random.Well19937c;
- import org.hipparchus.util.Pair;
- /**
- * Class for representing <a href="http://en.wikipedia.org/wiki/Mixture_model">
- * mixture model</a> distributions.
- *
- * @param <T> Type of the mixture components.
- */
- public class MixtureMultivariateRealDistribution<T extends MultivariateRealDistribution>
- extends AbstractMultivariateRealDistribution {
- /** Normalized weight of each mixture component. */
- private final double[] weight;
- /** Mixture components. */
- private final List<T> distribution;
- /**
- * Creates a mixture model from a list of distributions and their
- * associated weights.
- * <p>
- * <b>Note:</b> this constructor will implicitly create an instance of
- * {@link Well19937c} as random generator to be used for sampling only (see
- * {@link #sample()} and {@link #sample(int)}). In case no sampling is
- * needed for the created distribution, it is advised to pass {@code null}
- * as random generator via the appropriate constructors to avoid the
- * additional initialisation overhead.
- *
- * @param components List of (weight, distribution) pairs from which to sample.
- */
- public MixtureMultivariateRealDistribution(List<Pair<Double, T>> components) {
- this(new Well19937c(), components);
- }
- /**
- * Creates a mixture model from a list of distributions and their
- * associated weights.
- *
- * @param rng Random number generator.
- * @param components Distributions from which to sample.
- * @throws MathIllegalArgumentException if any of the weights is negative.
- * @throws MathIllegalArgumentException if not all components have the same
- * number of variables.
- */
- public MixtureMultivariateRealDistribution(RandomGenerator rng,
- List<Pair<Double, T>> components) {
- super(rng, components.get(0).getSecond().getDimension());
- final int numComp = components.size();
- final int dim = getDimension();
- double weightSum = 0;
- for (final Pair<Double, T> comp : components) {
- if (comp.getSecond().getDimension() != dim) {
- throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
- comp.getSecond().getDimension(), dim);
- }
- if (comp.getFirst() < 0) {
- throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL, comp.getFirst(), 0);
- }
- weightSum += comp.getFirst();
- }
- // Check for overflow.
- if (Double.isInfinite(weightSum)) {
- throw new MathRuntimeException(LocalizedCoreFormats.OVERFLOW);
- }
- // Store each distribution and its normalized weight.
- distribution = new ArrayList<>();
- weight = new double[numComp];
- for (int i = 0; i < numComp; i++) {
- final Pair<Double, T> comp = components.get(i);
- weight[i] = comp.getFirst() / weightSum;
- distribution.add(comp.getSecond());
- }
- }
- /** {@inheritDoc} */
- @Override
- public double density(final double[] values) {
- double p = 0;
- for (int i = 0; i < weight.length; i++) {
- p += weight[i] * distribution.get(i).density(values);
- }
- return p;
- }
- /** {@inheritDoc} */
- @Override
- public double[] sample() {
- // Sampled values.
- double[] vals = null;
- // Determine which component to sample from.
- final double randomValue = random.nextDouble();
- double sum = 0;
- for (int i = 0; i < weight.length; i++) {
- sum += weight[i];
- if (randomValue <= sum) {
- // pick model i
- vals = distribution.get(i).sample();
- break;
- }
- }
- if (vals == null) {
- // This should never happen, but it ensures we won't return a null in
- // case the loop above has some floating point inequality problem on
- // the final iteration.
- vals = distribution.get(weight.length - 1).sample();
- }
- return vals;
- }
- /** {@inheritDoc} */
- @Override
- public void reseedRandomGenerator(long seed) {
- // Seed needs to be propagated to underlying components
- // in order to maintain consistency between runs.
- super.reseedRandomGenerator(seed);
- for (int i = 0; i < distribution.size(); i++) {
- // Make each component's seed different in order to avoid
- // using the same sequence of random numbers.
- distribution.get(i).reseedRandomGenerator(i + 1 + seed);
- }
- }
- /**
- * Gets the distributions that make up the mixture model.
- *
- * @return the component distributions and associated weights.
- */
- public List<Pair<Double, T>> getComponents() {
- final List<Pair<Double, T>> list = new ArrayList<>(weight.length);
- for (int i = 0; i < weight.length; i++) {
- list.add(new Pair<>(weight[i], distribution.get(i)));
- }
- return list;
- }
- }