1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.hipparchus.stat.projection;
18
19 import org.hipparchus.exception.MathIllegalStateException;
20 import org.hipparchus.linear.EigenDecompositionSymmetric;
21 import org.hipparchus.linear.MatrixUtils;
22 import org.hipparchus.linear.RealMatrix;
23 import org.hipparchus.stat.LocalizedStatFormats;
24 import org.hipparchus.stat.StatUtils;
25 import org.hipparchus.stat.correlation.Covariance;
26 import org.hipparchus.stat.descriptive.moment.StandardDeviation;
27
28
29
30
31
32
33
34
35 public class PCA {
36
37
38
39 private final int numC;
40
41
42
43
44 private final boolean scale;
45
46
47
48
49 private final boolean biasCorrection;
50
51
52
53
54 private double[] center;
55
56
57
58
59 private double[] std;
60
61
62
63
64 private double[] eigenValues;
65
66
67
68
69 private RealMatrix principalComponents;
70
71
72
73
74 private final StandardDeviation sd;
75
76
77
78
79
80
81
82
83 public PCA(int numC, boolean scale, boolean biasCorrection) {
84 this.numC = numC;
85 this.scale = scale;
86 this.biasCorrection = biasCorrection;
87 sd = scale ? new StandardDeviation(biasCorrection) : null;
88 }
89
90
91
92
93
94
95 public PCA(int numC) {
96 this(numC, false, true);
97 }
98
99
100
101
102 public int getNumComponents() {
103 return numC;
104 }
105
106
107
108
109 public boolean isScale() {
110 return scale;
111 }
112
113
114
115
116 public boolean isBiasCorrection() {
117 return biasCorrection;
118 }
119
120
121
122
123 public double[] getVariance() {
124 validateState("getVariance");
125 return eigenValues.clone();
126 }
127
128
129
130
131 public double[] getCenter() {
132 validateState("getCenter");
133 return center.clone();
134 }
135
136
137
138
139
140
141
142 public double[][] getComponents() {
143 validateState("getComponents");
144 return principalComponents.getData();
145 }
146
147
148
149
150
151
152
153 public double[][] fitAndTransform(double[][] data) {
154 center = null;
155 RealMatrix normalizedM = getNormalizedMatrix(data);
156 calculatePrincipalComponents(normalizedM);
157 return normalizedM.multiply(principalComponents).getData();
158 }
159
160
161
162
163
164
165
166 public double[][] transform(double[][] data) {
167 validateState("transform");
168 RealMatrix normalizedM = getNormalizedMatrix(data);
169 return normalizedM.multiply(principalComponents).getData();
170 }
171
172
173
174
175
176
177
178 public PCA fit(double[][] data) {
179 center = null;
180 RealMatrix normalized = getNormalizedMatrix(data);
181 calculatePrincipalComponents(normalized);
182 return this;
183 }
184
185
186
187
188
189 private void validateState(String from) {
190 if (center == null) {
191 throw new MathIllegalStateException(LocalizedStatFormats.ILLEGAL_STATE_PCA, from);
192 }
193
194 }
195
196
197
198
199
200
201
202 private void calculatePrincipalComponents(RealMatrix normalizedM) {
203 RealMatrix covarianceM = new Covariance(normalizedM).getCovarianceMatrix();
204 EigenDecompositionSymmetric decomposition = new EigenDecompositionSymmetric(covarianceM);
205 eigenValues = decomposition.getEigenvalues();
206 principalComponents = MatrixUtils.createRealMatrix(eigenValues.length, numC);
207 for (int c = 0; c < numC; c++) {
208 for (int f = 0; f < eigenValues.length; f++) {
209 principalComponents.setEntry(f, c, decomposition.getEigenvector(c).getEntry(f));
210 }
211 }
212 }
213
214
215
216
217
218
219
220
221 private RealMatrix getNormalizedMatrix(double[][] input) {
222 int numS = input.length;
223 int numF = input[0].length;
224 boolean calculating = center == null;
225 if (calculating) {
226 center = new double[numF];
227 if (scale) {
228 std = new double[numF];
229 }
230 }
231
232 double[][] normalized = new double[numS][numF];
233 for (int f = 0; f < numF; f++) {
234 if (calculating) {
235 calculateNormalizeParameters(input, numS, f);
236 }
237 for (int s = 0; s < numS; s++) {
238 normalized[s][f] = input[s][f] - center[f];
239 }
240 if (scale) {
241 for (int s = 0; s < numS; s++) {
242 normalized[s][f] /= std[f];
243 }
244 }
245 }
246
247 return MatrixUtils.createRealMatrix(normalized);
248 }
249
250
251
252
253
254
255 private void calculateNormalizeParameters(double[][] input, int numS, int f) {
256 double[] column = new double[numS];
257 for (int s = 0; s < numS; s++) {
258 column[s] = input[s][f];
259 }
260 center[f] = StatUtils.mean(column);
261 if (scale) {
262 std[f] = sd.evaluate(column, center[f]);
263 }
264 }
265 }