MixtureMultivariateNormalDistribution.java
- /*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- /*
- * This is not the original file distributed by the Apache Software Foundation
- * It has been modified by the Hipparchus project
- */
- package org.hipparchus.distribution.multivariate;
- import java.util.ArrayList;
- import java.util.List;
- import org.hipparchus.exception.MathIllegalArgumentException;
- import org.hipparchus.random.RandomGenerator;
- import org.hipparchus.util.Pair;
- /**
- * Multivariate normal mixture distribution.
- * This class is mainly syntactic sugar.
- *
- * @see MixtureMultivariateRealDistribution
- */
- public class MixtureMultivariateNormalDistribution
- extends MixtureMultivariateRealDistribution<MultivariateNormalDistribution> {
- /**
- * Creates a multivariate normal mixture distribution.
- * <p>
- * <b>Note:</b> this constructor will implicitly create an instance of
- * {@link org.hipparchus.random.Well19937c Well19937c} as random
- * generator to be used for sampling only (see {@link #sample()} and
- * {@link #sample(int)}). In case no sampling is needed for the created
- * distribution, it is advised to pass {@code null} as random generator via
- * the appropriate constructors to avoid the additional initialisation
- * overhead.
- *
- * @param weights Weights of each component.
- * @param means Mean vector for each component.
- * @param covariances Covariance matrix for each component.
- */
- public MixtureMultivariateNormalDistribution(double[] weights,
- double[][] means,
- double[][][] covariances) {
- super(createComponents(weights, means, covariances));
- }
- /**
- * Creates a mixture model from a list of distributions and their
- * associated weights.
- * <p>
- * <b>Note:</b> this constructor will implicitly create an instance of
- * {@link org.hipparchus.random.Well19937c Well19937c} as random
- * generator to be used for sampling only (see {@link #sample()} and
- * {@link #sample(int)}). In case no sampling is needed for the created
- * distribution, it is advised to pass {@code null} as random generator via
- * the appropriate constructors to avoid the additional initialisation
- * overhead.
- *
- * @param components List of (weight, distribution) pairs from which to sample.
- */
- public MixtureMultivariateNormalDistribution(List<Pair<Double, MultivariateNormalDistribution>> components) {
- super(components);
- }
- /**
- * Creates a mixture model from a list of distributions and their
- * associated weights.
- *
- * @param rng Random number generator.
- * @param components Distributions from which to sample.
- * @throws MathIllegalArgumentException if any of the weights is negative.
- * @throws MathIllegalArgumentException if not all components have the same
- * number of variables.
- */
- public MixtureMultivariateNormalDistribution(RandomGenerator rng,
- List<Pair<Double, MultivariateNormalDistribution>> components)
- throws MathIllegalArgumentException {
- super(rng, components);
- }
- /**
- * @param weights Weights of each component.
- * @param means Mean vector for each component.
- * @param covariances Covariance matrix for each component.
- * @return the list of components.
- */
- private static List<Pair<Double, MultivariateNormalDistribution>> createComponents(double[] weights,
- double[][] means,
- double[][][] covariances) {
- final List<Pair<Double, MultivariateNormalDistribution>> mvns = new ArrayList<>(weights.length);
- for (int i = 0; i < weights.length; i++) {
- final MultivariateNormalDistribution dist
- = new MultivariateNormalDistribution(means[i], covariances[i]);
- mvns.add(new Pair<>(weights[i], dist));
- }
- return mvns;
- }
- }