1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 package org.hipparchus.distribution.multivariate;
23
24 import java.util.ArrayList;
25 import java.util.List;
26
27 import org.hipparchus.distribution.MultivariateRealDistribution;
28 import org.hipparchus.exception.LocalizedCoreFormats;
29 import org.hipparchus.exception.MathIllegalArgumentException;
30 import org.hipparchus.exception.MathRuntimeException;
31 import org.hipparchus.random.RandomGenerator;
32 import org.hipparchus.random.Well19937c;
33 import org.hipparchus.util.Pair;
34
35
36
37
38
39
40
41 public class MixtureMultivariateRealDistribution<T extends MultivariateRealDistribution>
42 extends AbstractMultivariateRealDistribution {
43
44 private final double[] weight;
45
46 private final List<T> distribution;
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61 public MixtureMultivariateRealDistribution(List<Pair<Double, T>> components) {
62 this(new Well19937c(), components);
63 }
64
65
66
67
68
69
70
71
72
73
74
75 public MixtureMultivariateRealDistribution(RandomGenerator rng,
76 List<Pair<Double, T>> components) {
77 super(rng, components.get(0).getSecond().getDimension());
78
79 final int numComp = components.size();
80 final int dim = getDimension();
81 double weightSum = 0;
82 for (int i = 0; i < numComp; i++) {
83 final Pair<Double, T> comp = components.get(i);
84 if (comp.getSecond().getDimension() != dim) {
85 throw new MathIllegalArgumentException(LocalizedCoreFormats.DIMENSIONS_MISMATCH,
86 comp.getSecond().getDimension(), dim);
87 }
88 if (comp.getFirst() < 0) {
89 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL, comp.getFirst(), 0);
90 }
91 weightSum += comp.getFirst();
92 }
93
94
95 if (Double.isInfinite(weightSum)) {
96 throw new MathRuntimeException(LocalizedCoreFormats.OVERFLOW);
97 }
98
99
100 distribution = new ArrayList<>();
101 weight = new double[numComp];
102 for (int i = 0; i < numComp; i++) {
103 final Pair<Double, T> comp = components.get(i);
104 weight[i] = comp.getFirst() / weightSum;
105 distribution.add(comp.getSecond());
106 }
107 }
108
109
110 @Override
111 public double density(final double[] values) {
112 double p = 0;
113 for (int i = 0; i < weight.length; i++) {
114 p += weight[i] * distribution.get(i).density(values);
115 }
116 return p;
117 }
118
119
120 @Override
121 public double[] sample() {
122
123 double[] vals = null;
124
125
126 final double randomValue = random.nextDouble();
127 double sum = 0;
128
129 for (int i = 0; i < weight.length; i++) {
130 sum += weight[i];
131 if (randomValue <= sum) {
132
133 vals = distribution.get(i).sample();
134 break;
135 }
136 }
137
138 if (vals == null) {
139
140
141
142 vals = distribution.get(weight.length - 1).sample();
143 }
144
145 return vals;
146 }
147
148
149 @Override
150 public void reseedRandomGenerator(long seed) {
151
152
153 super.reseedRandomGenerator(seed);
154
155 for (int i = 0; i < distribution.size(); i++) {
156
157
158 distribution.get(i).reseedRandomGenerator(i + 1 + seed);
159 }
160 }
161
162
163
164
165
166
167 public List<Pair<Double, T>> getComponents() {
168 final List<Pair<Double, T>> list = new ArrayList<>(weight.length);
169
170 for (int i = 0; i < weight.length; i++) {
171 list.add(new Pair<>(weight[i], distribution.get(i)));
172 }
173
174 return list;
175 }
176 }