1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 package org.hipparchus.optim.nonlinear.vector.leastsquares;
23
24 import java.io.BufferedReader;
25 import java.io.IOException;
26 import java.util.ArrayList;
27
28 import org.hipparchus.analysis.MultivariateMatrixFunction;
29 import org.hipparchus.analysis.MultivariateVectorFunction;
30
31
32
33
34
35
36
37
38 public abstract class StatisticalReferenceDataset {
39
40 private final String name;
41
42 private final int numObservations;
43
44 private final int numParameters;
45
46 private final int numStartingPoints;
47
48 private final double[] x;
49
50 private final double[] y;
51
52
53
54
55 private final double[][] startingValues;
56
57 private final double[] a;
58
59 private final double[] sigA;
60
61 private double residualSumOfSquares;
62
63 private final LeastSquaresProblem problem;
64
65
66
67
68
69
70
71
72 public StatisticalReferenceDataset(final BufferedReader in)
73 throws IOException {
74
75 final ArrayList<String> lines = new ArrayList<String>();
76 for (String line = in.readLine(); line != null; line = in.readLine()) {
77 lines.add(line);
78 }
79 int[] index = findLineNumbers("Data", lines);
80 if (index == null) {
81 throw new AssertionError("could not find line indices for data");
82 }
83 this.numObservations = index[1] - index[0] + 1;
84 this.x = new double[this.numObservations];
85 this.y = new double[this.numObservations];
86 for (int i = 0; i < this.numObservations; i++) {
87 final String line = lines.get(index[0] + i - 1);
88 final String[] tokens = line.trim().split(" ++");
89
90 this.y[i] = Double.parseDouble(tokens[0]);
91 this.x[i] = Double.parseDouble(tokens[1]);
92 }
93
94 index = findLineNumbers("Starting Values", lines);
95 if (index == null) {
96 throw new AssertionError(
97 "could not find line indices for starting values");
98 }
99 this.numParameters = index[1] - index[0] + 1;
100
101 double[][] start = null;
102 this.a = new double[numParameters];
103 this.sigA = new double[numParameters];
104 for (int i = 0; i < numParameters; i++) {
105 final String line = lines.get(index[0] + i - 1);
106 final String[] tokens = line.trim().split(" ++");
107 if (start == null) {
108 start = new double[tokens.length - 4][numParameters];
109 }
110 for (int j = 2; j < tokens.length - 2; j++) {
111 start[j - 2][i] = Double.parseDouble(tokens[j]);
112 }
113 this.a[i] = Double.parseDouble(tokens[tokens.length - 2]);
114 this.sigA[i] = Double.parseDouble(tokens[tokens.length - 1]);
115 }
116 if (start == null) {
117 throw new IOException("could not find starting values");
118 }
119 this.numStartingPoints = start.length;
120 this.startingValues = start;
121
122 double dummyDouble = Double.NaN;
123 String dummyString = null;
124 for (String line : lines) {
125 if (line.contains("Dataset Name:")) {
126 dummyString = line
127 .substring(line.indexOf("Dataset Name:") + 13,
128 line.indexOf("(")).trim();
129 }
130 if (line.contains("Residual Sum of Squares")) {
131 final String[] tokens = line.split(" ++");
132 dummyDouble = Double.parseDouble(tokens[4].trim());
133 }
134 }
135 if (Double.isNaN(dummyDouble)) {
136 throw new IOException(
137 "could not find certified value of residual sum of squares");
138 }
139 this.residualSumOfSquares = dummyDouble;
140
141 if (dummyString == null) {
142 throw new IOException("could not find dataset name");
143 }
144 this.name = dummyString;
145
146 this.problem = new LeastSquaresProblem();
147 }
148
149 class LeastSquaresProblem {
150 public MultivariateVectorFunction getModelFunction() {
151 return new MultivariateVectorFunction() {
152 public double[] value(final double[] a) {
153 final int n = getNumObservations();
154 final double[] yhat = new double[n];
155 for (int i = 0; i < n; i++) {
156 yhat[i] = getModelValue(getX(i), a);
157 }
158 return yhat;
159 }
160 };
161 }
162
163 public MultivariateMatrixFunction getModelFunctionJacobian() {
164 return new MultivariateMatrixFunction() {
165 public double[][] value(final double[] a)
166 throws IllegalArgumentException {
167 final int n = getNumObservations();
168 final double[][] j = new double[n][];
169 for (int i = 0; i < n; i++) {
170 j[i] = getModelDerivatives(getX(i), a);
171 }
172 return j;
173 }
174 };
175 }
176 }
177
178
179
180
181
182
183 public String getName() {
184 return name;
185 }
186
187
188
189
190
191
192 public int getNumObservations() {
193 return numObservations;
194 }
195
196
197
198
199
200
201
202 public double[][] getData() {
203 return new double[][] { x.clone(), y.clone() };
204 }
205
206
207
208
209
210
211
212 public double getX(final int i) {
213 return x[i];
214 }
215
216
217
218
219
220
221
222 public double getY(final int i) {
223 return y[i];
224 }
225
226
227
228
229
230
231 public int getNumParameters() {
232 return numParameters;
233 }
234
235
236
237
238
239
240 public double[] getParameters() {
241 return a.clone();
242 }
243
244
245
246
247
248
249
250 public double getParameter(final int i) {
251 return a[i];
252 }
253
254
255
256
257
258
259 public double[] getParametersStandardDeviations() {
260 return sigA.clone();
261 }
262
263
264
265
266
267
268
269
270 public double getParameterStandardDeviation(final int i) {
271 return sigA[i];
272 }
273
274
275
276
277
278
279 public double getResidualSumOfSquares() {
280 return residualSumOfSquares;
281 }
282
283
284
285
286
287
288
289 public int getNumStartingPoints() {
290 return numStartingPoints;
291 }
292
293
294
295
296
297
298
299 public double[] getStartingPoint(final int i) {
300 return startingValues[i].clone();
301 }
302
303
304
305
306
307
308
309 public LeastSquaresProblem getLeastSquaresProblem() {
310 return problem;
311 }
312
313
314
315
316
317
318
319
320
321 public abstract double getModelValue(final double x, final double[] a);
322
323
324
325
326
327
328
329
330
331 public abstract double[] getModelDerivatives(final double x,
332 final double[] a);
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355 private static int[] findLineNumbers(final String key,
356 final Iterable<String> lines) {
357 for (String text : lines) {
358 boolean flag = text.contains(key) && text.contains("lines") &&
359 text.contains("to") && text.contains(")");
360 if (flag) {
361 final int[] numbers = new int[2];
362 final String from = text.substring(text.indexOf("lines") + 5,
363 text.indexOf("to"));
364 numbers[0] = Integer.parseInt(from.trim());
365 final String to = text.substring(text.indexOf("to") + 2,
366 text.indexOf(")"));
367 numbers[1] = Integer.parseInt(to.trim());
368 return numbers;
369 }
370 }
371 return null;
372 }
373 }