MixtureMultivariateNormalDistribution.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      https://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. /*
  18.  * This is not the original file distributed by the Apache Software Foundation
  19.  * It has been modified by the Hipparchus project
  20.  */
  21. package org.hipparchus.distribution.multivariate;

import java.util.ArrayList;
import java.util.List;

import org.hipparchus.exception.LocalizedCoreFormats;
import org.hipparchus.exception.MathIllegalArgumentException;
import org.hipparchus.random.RandomGenerator;
import org.hipparchus.util.Pair;

  27. /**
  28.  * Multivariate normal mixture distribution.
  29.  * This class is mainly syntactic sugar.
  30.  *
  31.  * @see MixtureMultivariateRealDistribution
  32.  */
  33. public class MixtureMultivariateNormalDistribution
  34.     extends MixtureMultivariateRealDistribution<MultivariateNormalDistribution> {

  35.     /**
  36.      * Creates a multivariate normal mixture distribution.
  37.      * <p>
  38.      * <b>Note:</b> this constructor will implicitly create an instance of
  39.      * {@link org.hipparchus.random.Well19937c Well19937c} as random
  40.      * generator to be used for sampling only (see {@link #sample()} and
  41.      * {@link #sample(int)}). In case no sampling is needed for the created
  42.      * distribution, it is advised to pass {@code null} as random generator via
  43.      * the appropriate constructors to avoid the additional initialisation
  44.      * overhead.
  45.      *
  46.      * @param weights Weights of each component.
  47.      * @param means Mean vector for each component.
  48.      * @param covariances Covariance matrix for each component.
  49.      */
  50.     public MixtureMultivariateNormalDistribution(double[] weights,
  51.                                                  double[][] means,
  52.                                                  double[][][] covariances) {
  53.         super(createComponents(weights, means, covariances));
  54.     }

  55.     /**
  56.      * Creates a mixture model from a list of distributions and their
  57.      * associated weights.
  58.      * <p>
  59.      * <b>Note:</b> this constructor will implicitly create an instance of
  60.      * {@link org.hipparchus.random.Well19937c Well19937c} as random
  61.      * generator to be used for sampling only (see {@link #sample()} and
  62.      * {@link #sample(int)}). In case no sampling is needed for the created
  63.      * distribution, it is advised to pass {@code null} as random generator via
  64.      * the appropriate constructors to avoid the additional initialisation
  65.      * overhead.
  66.      *
  67.      * @param components List of (weight, distribution) pairs from which to sample.
  68.      */
  69.     public MixtureMultivariateNormalDistribution(List<Pair<Double, MultivariateNormalDistribution>> components) {
  70.         super(components);
  71.     }

  72.     /**
  73.      * Creates a mixture model from a list of distributions and their
  74.      * associated weights.
  75.      *
  76.      * @param rng Random number generator.
  77.      * @param components Distributions from which to sample.
  78.      * @throws MathIllegalArgumentException if any of the weights is negative.
  79.      * @throws MathIllegalArgumentException if not all components have the same
  80.      * number of variables.
  81.      */
  82.     public MixtureMultivariateNormalDistribution(RandomGenerator rng,
  83.                                                  List<Pair<Double, MultivariateNormalDistribution>> components)
  84.         throws MathIllegalArgumentException {
  85.         super(rng, components);
  86.     }

  87.     /**
  88.      * @param weights Weights of each component.
  89.      * @param means Mean vector for each component.
  90.      * @param covariances Covariance matrix for each component.
  91.      * @return the list of components.
  92.      */
  93.     private static List<Pair<Double, MultivariateNormalDistribution>> createComponents(double[] weights,
  94.                                                                                        double[][] means,
  95.                                                                                        double[][][] covariances) {
  96.         final List<Pair<Double, MultivariateNormalDistribution>> mvns = new ArrayList<>(weights.length);

  97.         for (int i = 0; i < weights.length; i++) {
  98.             final MultivariateNormalDistribution dist
  99.                 = new MultivariateNormalDistribution(means[i], covariances[i]);

  100.             mvns.add(new Pair<>(weights[i], dist));
  101.         }

  102.         return mvns;
  103.     }
  104. }