View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  /*
19   * This is not the original file distributed by the Apache Software Foundation
20   * It has been modified by the Hipparchus project
21   */
22  package org.hipparchus.distribution.multivariate;
23  
24  import java.util.ArrayList;
25  import java.util.List;
26  
27  import org.hipparchus.exception.MathIllegalArgumentException;
28  import org.hipparchus.random.RandomGenerator;
29  import org.hipparchus.util.Pair;
30  
31  /**
32   * Multivariate normal mixture distribution.
33   * This class is mainly syntactic sugar.
34   *
35   * @see MixtureMultivariateRealDistribution
36   */
37  public class MixtureMultivariateNormalDistribution
38      extends MixtureMultivariateRealDistribution<MultivariateNormalDistribution> {
39  
40      /**
41       * Creates a multivariate normal mixture distribution.
42       * <p>
43       * <b>Note:</b> this constructor will implicitly create an instance of
44       * {@link org.hipparchus.random.Well19937c Well19937c} as random
45       * generator to be used for sampling only (see {@link #sample()} and
46       * {@link #sample(int)}). In case no sampling is needed for the created
47       * distribution, it is advised to pass {@code null} as random generator via
48       * the appropriate constructors to avoid the additional initialisation
49       * overhead.
50       *
51       * @param weights Weights of each component.
52       * @param means Mean vector for each component.
53       * @param covariances Covariance matrix for each component.
54       */
55      public MixtureMultivariateNormalDistribution(double[] weights,
56                                                   double[][] means,
57                                                   double[][][] covariances) {
58          super(createComponents(weights, means, covariances));
59      }
60  
61      /**
62       * Creates a mixture model from a list of distributions and their
63       * associated weights.
64       * <p>
65       * <b>Note:</b> this constructor will implicitly create an instance of
66       * {@link org.hipparchus.random.Well19937c Well19937c} as random
67       * generator to be used for sampling only (see {@link #sample()} and
68       * {@link #sample(int)}). In case no sampling is needed for the created
69       * distribution, it is advised to pass {@code null} as random generator via
70       * the appropriate constructors to avoid the additional initialisation
71       * overhead.
72       *
73       * @param components List of (weight, distribution) pairs from which to sample.
74       */
75      public MixtureMultivariateNormalDistribution(List<Pair<Double, MultivariateNormalDistribution>> components) {
76          super(components);
77      }
78  
79      /**
80       * Creates a mixture model from a list of distributions and their
81       * associated weights.
82       *
83       * @param rng Random number generator.
84       * @param components Distributions from which to sample.
85       * @throws MathIllegalArgumentException if any of the weights is negative.
86       * @throws MathIllegalArgumentException if not all components have the same
87       * number of variables.
88       */
89      public MixtureMultivariateNormalDistribution(RandomGenerator rng,
90                                                   List<Pair<Double, MultivariateNormalDistribution>> components)
91          throws MathIllegalArgumentException {
92          super(rng, components);
93      }
94  
95      /**
96       * @param weights Weights of each component.
97       * @param means Mean vector for each component.
98       * @param covariances Covariance matrix for each component.
99       * @return the list of components.
100      */
101     private static List<Pair<Double, MultivariateNormalDistribution>> createComponents(double[] weights,
102                                                                                        double[][] means,
103                                                                                        double[][][] covariances) {
104         final List<Pair<Double, MultivariateNormalDistribution>> mvns = new ArrayList<>(weights.length);
105 
106         for (int i = 0; i < weights.length; i++) {
107             final MultivariateNormalDistribution dist
108                 = new MultivariateNormalDistribution(means[i], covariances[i]);
109 
110             mvns.add(new Pair<>(weights[i], dist));
111         }
112 
113         return mvns;
114     }
115 }