1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 package org.hipparchus.fitting;
23
24 import java.util.ArrayList;
25 import java.util.Collection;
26 import java.util.Collections;
27 import java.util.Comparator;
28 import java.util.List;
29
30 import org.hipparchus.analysis.function.Gaussian;
31 import org.hipparchus.exception.LocalizedCoreFormats;
32 import org.hipparchus.exception.MathIllegalArgumentException;
33 import org.hipparchus.linear.DiagonalMatrix;
34 import org.hipparchus.optim.nonlinear.vector.leastsquares.LeastSquaresBuilder;
35 import org.hipparchus.optim.nonlinear.vector.leastsquares.LeastSquaresProblem;
36 import org.hipparchus.util.FastMath;
37 import org.hipparchus.util.MathUtils;
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73 public class GaussianCurveFitter extends AbstractCurveFitter {
74
75 private static final Gaussian.Parametric FUNCTION = new Gaussian.Parametric() {
76
77 @Override
78 public double value(double x, double ... p) {
79 double v = Double.POSITIVE_INFINITY;
80 try {
81 v = super.value(x, p);
82 } catch (MathIllegalArgumentException e) {
83
84 }
85 return v;
86 }
87
88
89 @Override
90 public double[] gradient(double x, double ... p) {
91 double[] v = { Double.POSITIVE_INFINITY,
92 Double.POSITIVE_INFINITY,
93 Double.POSITIVE_INFINITY };
94 try {
95 v = super.gradient(x, p);
96 } catch (MathIllegalArgumentException e) {
97
98 }
99 return v;
100 }
101 };
102
103 private final double[] initialGuess;
104
105 private final int maxIter;
106
107
108
109
110
111
112
113
114 private GaussianCurveFitter(double[] initialGuess, int maxIter) {
115 this.initialGuess = initialGuess == null ? null : initialGuess.clone();
116 this.maxIter = maxIter;
117 }
118
119
120
121
122
123
124
125
126
127
128
129
130 public static GaussianCurveFitter create() {
131 return new GaussianCurveFitter(null, Integer.MAX_VALUE);
132 }
133
134
135
136
137
138
139 public GaussianCurveFitter withStartPoint(double[] newStart) {
140 return new GaussianCurveFitter(newStart.clone(),
141 maxIter);
142 }
143
144
145
146
147
148
149 public GaussianCurveFitter withMaxIterations(int newMaxIter) {
150 return new GaussianCurveFitter(initialGuess,
151 newMaxIter);
152 }
153
154
155 @Override
156 protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) {
157
158
159 final int len = observations.size();
160 final double[] target = new double[len];
161 final double[] weights = new double[len];
162
163 int i = 0;
164 for (WeightedObservedPoint obs : observations) {
165 target[i] = obs.getY();
166 weights[i] = obs.getWeight();
167 ++i;
168 }
169
170 final AbstractCurveFitter.TheoreticalValuesFunction model =
171 new AbstractCurveFitter.TheoreticalValuesFunction(FUNCTION, observations);
172
173 final double[] startPoint = initialGuess != null ?
174 initialGuess :
175
176 new ParameterGuesser(observations).guess();
177
178
179
180 return new LeastSquaresBuilder().
181 maxEvaluations(Integer.MAX_VALUE).
182 maxIterations(maxIter).
183 start(startPoint).
184 target(target).
185 weight(new DiagonalMatrix(weights)).
186 model(model.getModelFunction(), model.getModelFunctionJacobian()).
187 build();
188
189 }
190
191
192
193
194
195
196 public static class ParameterGuesser {
197
198 private final double norm;
199
200 private final double mean;
201
202 private final double sigma;
203
204
205
206
207
208
209
210
211
212
213
214 public ParameterGuesser(Collection<WeightedObservedPoint> observations) {
215 MathUtils.checkNotNull(observations);
216 if (observations.size() < 3) {
217 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL,
218 observations.size(), 3);
219 }
220
221 final List<WeightedObservedPoint> sorted = sortObservations(observations);
222 final double[] params = basicGuess(sorted.toArray(new WeightedObservedPoint[0]));
223
224 norm = params[0];
225 mean = params[1];
226 sigma = params[2];
227 }
228
229
230
231
232
233
234
235
236
237
238
239 public double[] guess() {
240 return new double[] { norm, mean, sigma };
241 }
242
243
244
245
246
247
248
249 private List<WeightedObservedPoint> sortObservations(Collection<WeightedObservedPoint> unsorted) {
250 final List<WeightedObservedPoint> observations = new ArrayList<>(unsorted);
251
252 final Comparator<WeightedObservedPoint> cmp = new Comparator<WeightedObservedPoint>() {
253
254 @Override
255 public int compare(WeightedObservedPoint p1,
256 WeightedObservedPoint p2) {
257 if (p1 == null && p2 == null) {
258 return 0;
259 }
260 if (p1 == null) {
261 return -1;
262 }
263 if (p2 == null) {
264 return 1;
265 }
266 int comp = Double.compare(p1.getX(), p2.getX());
267 if (comp != 0) {
268 return comp;
269 }
270 comp = Double.compare(p1.getY(), p2.getY());
271 if (comp != 0) {
272 return comp;
273 }
274 comp = Double.compare(p1.getWeight(), p2.getWeight());
275 if (comp != 0) {
276 return comp;
277 }
278 return 0;
279 }
280 };
281
282 Collections.sort(observations, cmp);
283 return observations;
284 }
285
286
287
288
289
290
291
292
293 private double[] basicGuess(WeightedObservedPoint[] points) {
294 final int maxYIdx = findMaxY(points);
295 final double n = points[maxYIdx].getY();
296 final double m = points[maxYIdx].getX();
297
298 double fwhmApprox;
299 try {
300 final double halfY = n + ((m - n) / 2);
301 final double fwhmX1 = interpolateXAtY(points, maxYIdx, -1, halfY);
302 final double fwhmX2 = interpolateXAtY(points, maxYIdx, 1, halfY);
303 fwhmApprox = fwhmX2 - fwhmX1;
304 } catch (MathIllegalArgumentException e) {
305
306 fwhmApprox = points[points.length - 1].getX() - points[0].getX();
307 }
308 final double s = fwhmApprox / (2 * FastMath.sqrt(2 * FastMath.log(2)));
309
310 return new double[] { n, m, s };
311 }
312
313
314
315
316
317
318
319 private int findMaxY(WeightedObservedPoint[] points) {
320 int maxYIdx = 0;
321 for (int i = 1; i < points.length; i++) {
322 if (points[i].getY() > points[maxYIdx].getY()) {
323 maxYIdx = i;
324 }
325 }
326 return maxYIdx;
327 }
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343 private double interpolateXAtY(WeightedObservedPoint[] points,
344 int startIdx,
345 int idxStep,
346 double y)
347 throws MathIllegalArgumentException {
348 if (idxStep == 0) {
349 throw new MathIllegalArgumentException(LocalizedCoreFormats.ZERO_NOT_ALLOWED);
350 }
351 final WeightedObservedPoint[] twoPoints
352 = getInterpolationPointsForY(points, startIdx, idxStep, y);
353 final WeightedObservedPoint p1 = twoPoints[0];
354 final WeightedObservedPoint p2 = twoPoints[1];
355 if (p1.getY() == y) {
356 return p1.getX();
357 }
358 if (p2.getY() == y) {
359 return p2.getX();
360 }
361 return p1.getX() + (((y - p1.getY()) * (p2.getX() - p1.getX())) /
362 (p2.getY() - p1.getY()));
363 }
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380 private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points,
381 int startIdx,
382 int idxStep,
383 double y)
384 throws MathIllegalArgumentException {
385 if (idxStep == 0) {
386 throw new MathIllegalArgumentException(LocalizedCoreFormats.ZERO_NOT_ALLOWED);
387 }
388 for (int i = startIdx;
389 idxStep < 0 ? i + idxStep >= 0 : i + idxStep < points.length;
390 i += idxStep) {
391 final WeightedObservedPoint p1 = points[i];
392 final WeightedObservedPoint p2 = points[i + idxStep];
393 if (isBetween(y, p1.getY(), p2.getY())) {
394 if (idxStep < 0) {
395 return new WeightedObservedPoint[] { p2, p1 };
396 } else {
397 return new WeightedObservedPoint[] { p1, p2 };
398 }
399 }
400 }
401
402
403
404
405 throw new MathIllegalArgumentException(LocalizedCoreFormats.OUT_OF_RANGE_SIMPLE,
406 y, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY);
407 }
408
409
410
411
412
413
414
415
416
417
418
419 private boolean isBetween(double value,
420 double boundary1,
421 double boundary2) {
422 return (value >= boundary1 && value <= boundary2) ||
423 (value >= boundary2 && value <= boundary1);
424 }
425 }
426 }