1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * https://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 /*
19 * This is not the original file distributed by the Apache Software Foundation
20 * It has been modified by the Hipparchus project
21 */
22 package org.hipparchus.distribution.multivariate;
23
24 import java.util.ArrayList;
25 import java.util.List;
26
27 import org.hipparchus.distribution.MultivariateRealDistribution;
28 import org.hipparchus.exception.LocalizedCoreFormats;
29 import org.hipparchus.exception.MathIllegalArgumentException;
30 import org.hipparchus.exception.MathRuntimeException;
31 import org.hipparchus.random.RandomGenerator;
32 import org.hipparchus.random.Well19937c;
33 import org.hipparchus.util.Pair;
34
35 /**
36 * Class for representing <a href="http://en.wikipedia.org/wiki/Mixture_model">
37 * mixture model</a> distributions.
38 *
39 * @param <T> Type of the mixture components.
40 */
41 public class MixtureMultivariateRealDistribution<T extends MultivariateRealDistribution>
42 extends AbstractMultivariateRealDistribution {
43 /** Normalized weight of each mixture component. */
44 private final double[] weight;
45 /** Mixture components. */
46 private final List<T> distribution;
47
48 /**
49 * Creates a mixture model from a list of distributions and their
50 * associated weights.
51 * <p>
52 * <b>Note:</b> this constructor will implicitly create an instance of
53 * {@link Well19937c} as random generator to be used for sampling only (see
54 * {@link #sample()} and {@link #sample(int)}). In case no sampling is
55 * needed for the created distribution, it is advised to pass {@code null}
56 * as random generator via the appropriate constructors to avoid the
57 * additional initialisation overhead.
58 *
59 * @param components List of (weight, distribution) pairs from which to sample.
60 */
61 public MixtureMultivariateRealDistribution(List<Pair<Double, T>> components) {
62 this(new Well19937c(), components);
63 }
64
65 /**
66 * Creates a mixture model from a list of distributions and their
67 * associated weights.
68 *
69 * @param rng Random number generator.
70 * @param components Distributions from which to sample.
71 * @throws MathIllegalArgumentException if any of the weights is negative.
72 * @throws MathIllegalArgumentException if not all components have the same
73 * number of variables.
74 */
75 public MixtureMultivariateRealDistribution(RandomGenerator rng,
76 List<Pair<Double, T>> components) {
77 super(rng, components.get(0).getSecond().getDimension());
78
79 final int numComp = components.size();
80 final int dim = getDimension();
81 double weightSum = 0;
82 for (final Pair<Double, T> comp : components) {
83 if (comp.getSecond().getDimension() != dim) {
84 throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
85 comp.getSecond().getDimension(), dim);
86 }
87 if (comp.getFirst() < 0) {
88 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL, comp.getFirst(), 0);
89 }
90 weightSum += comp.getFirst();
91 }
92
93 // Check for overflow.
94 if (Double.isInfinite(weightSum)) {
95 throw new MathRuntimeException(LocalizedCoreFormats.OVERFLOW);
96 }
97
98 // Store each distribution and its normalized weight.
99 distribution = new ArrayList<>();
100 weight = new double[numComp];
101 for (int i = 0; i < numComp; i++) {
102 final Pair<Double, T> comp = components.get(i);
103 weight[i] = comp.getFirst() / weightSum;
104 distribution.add(comp.getSecond());
105 }
106 }
107
108 /** {@inheritDoc} */
109 @Override
110 public double density(final double[] values) {
111 double p = 0;
112 for (int i = 0; i < weight.length; i++) {
113 p += weight[i] * distribution.get(i).density(values);
114 }
115 return p;
116 }
117
118 /** {@inheritDoc} */
119 @Override
120 public double[] sample() {
121 // Sampled values.
122 double[] vals = null;
123
124 // Determine which component to sample from.
125 final double randomValue = random.nextDouble();
126 double sum = 0;
127
128 for (int i = 0; i < weight.length; i++) {
129 sum += weight[i];
130 if (randomValue <= sum) {
131 // pick model i
132 vals = distribution.get(i).sample();
133 break;
134 }
135 }
136
137 if (vals == null) {
138 // This should never happen, but it ensures we won't return a null in
139 // case the loop above has some floating point inequality problem on
140 // the final iteration.
141 vals = distribution.get(weight.length - 1).sample();
142 }
143
144 return vals;
145 }
146
147 /** {@inheritDoc} */
148 @Override
149 public void reseedRandomGenerator(long seed) {
150 // Seed needs to be propagated to underlying components
151 // in order to maintain consistency between runs.
152 super.reseedRandomGenerator(seed);
153
154 for (int i = 0; i < distribution.size(); i++) {
155 // Make each component's seed different in order to avoid
156 // using the same sequence of random numbers.
157 distribution.get(i).reseedRandomGenerator(i + 1 + seed);
158 }
159 }
160
161 /**
162 * Gets the distributions that make up the mixture model.
163 *
164 * @return the component distributions and associated weights.
165 */
166 public List<Pair<Double, T>> getComponents() {
167 final List<Pair<Double, T>> list = new ArrayList<>(weight.length);
168
169 for (int i = 0; i < weight.length; i++) {
170 list.add(new Pair<>(weight[i], distribution.get(i)));
171 }
172
173 return list;
174 }
175 }