1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * https://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 /*
19 * This is not the original file distributed by the Apache Software Foundation
20 * It has been modified by the Hipparchus project
21 */
22 package org.hipparchus.distribution.multivariate;
23
24 import java.util.ArrayList;
25 import java.util.List;
26
27 import org.hipparchus.exception.MathIllegalArgumentException;
28 import org.hipparchus.random.RandomGenerator;
29 import org.hipparchus.util.Pair;
30
31 /**
32 * Multivariate normal mixture distribution.
33 * This class is mainly syntactic sugar.
34 *
35 * @see MixtureMultivariateRealDistribution
36 */
37 public class MixtureMultivariateNormalDistribution
38 extends MixtureMultivariateRealDistribution<MultivariateNormalDistribution> {
39
40 /**
41 * Creates a multivariate normal mixture distribution.
42 * <p>
43 * <b>Note:</b> this constructor will implicitly create an instance of
44 * {@link org.hipparchus.random.Well19937c Well19937c} as random
45 * generator to be used for sampling only (see {@link #sample()} and
46 * {@link #sample(int)}). In case no sampling is needed for the created
47 * distribution, it is advised to pass {@code null} as random generator via
48 * the appropriate constructors to avoid the additional initialisation
49 * overhead.
50 *
51 * @param weights Weights of each component.
52 * @param means Mean vector for each component.
53 * @param covariances Covariance matrix for each component.
54 */
55 public MixtureMultivariateNormalDistribution(double[] weights,
56 double[][] means,
57 double[][][] covariances) {
58 super(createComponents(weights, means, covariances));
59 }
60
61 /**
62 * Creates a mixture model from a list of distributions and their
63 * associated weights.
64 * <p>
65 * <b>Note:</b> this constructor will implicitly create an instance of
66 * {@link org.hipparchus.random.Well19937c Well19937c} as random
67 * generator to be used for sampling only (see {@link #sample()} and
68 * {@link #sample(int)}). In case no sampling is needed for the created
69 * distribution, it is advised to pass {@code null} as random generator via
70 * the appropriate constructors to avoid the additional initialisation
71 * overhead.
72 *
73 * @param components List of (weight, distribution) pairs from which to sample.
74 */
75 public MixtureMultivariateNormalDistribution(List<Pair<Double, MultivariateNormalDistribution>> components) {
76 super(components);
77 }
78
79 /**
80 * Creates a mixture model from a list of distributions and their
81 * associated weights.
82 *
83 * @param rng Random number generator.
84 * @param components Distributions from which to sample.
85 * @throws MathIllegalArgumentException if any of the weights is negative.
86 * @throws MathIllegalArgumentException if not all components have the same
87 * number of variables.
88 */
89 public MixtureMultivariateNormalDistribution(RandomGenerator rng,
90 List<Pair<Double, MultivariateNormalDistribution>> components)
91 throws MathIllegalArgumentException {
92 super(rng, components);
93 }
94
95 /**
96 * @param weights Weights of each component.
97 * @param means Mean vector for each component.
98 * @param covariances Covariance matrix for each component.
99 * @return the list of components.
100 */
101 private static List<Pair<Double, MultivariateNormalDistribution>> createComponents(double[] weights,
102 double[][] means,
103 double[][][] covariances) {
104 final List<Pair<Double, MultivariateNormalDistribution>> mvns = new ArrayList<>(weights.length);
105
106 for (int i = 0; i < weights.length; i++) {
107 final MultivariateNormalDistribution dist
108 = new MultivariateNormalDistribution(means[i], covariances[i]);
109
110 mvns.add(new Pair<>(weights[i], dist));
111 }
112
113 return mvns;
114 }
115 }