/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17
/*
 * This is not the original file distributed by the Apache Software Foundation.
 * It has been modified by the Hipparchus project.
 */
22 package org.hipparchus.optim.nonlinear.vector.leastsquares;
23
24 import java.io.BufferedReader;
25 import java.io.IOException;
26 import java.util.ArrayList;
27
28 import org.hipparchus.analysis.MultivariateMatrixFunction;
29 import org.hipparchus.analysis.MultivariateVectorFunction;
30
31 /**
32 * This class gives access to the statistical reference datasets provided by the
33 * NIST (available
34 * <a href="http://www.itl.nist.gov/div898/strd/general/dataarchive.html">here</a>).
35 * Instances of this class can be created by invocation of the
36 * {@link StatisticalReferenceDatasetFactory}.
37 */
38 public abstract class StatisticalReferenceDataset {
39 /** The name of this dataset. */
40 private final String name;
41 /** The total number of observations (data points). */
42 private final int numObservations;
43 /** The total number of parameters. */
44 private final int numParameters;
45 /** The total number of starting points for the optimizations. */
46 private final int numStartingPoints;
47 /** The values of the predictor. */
48 private final double[] x;
49 /** The values of the response. */
50 private final double[] y;
51 /**
52 * The starting values. {@code startingValues[j][i]} is the value of the
53 * {@code i}-th parameter in the {@code j}-th set of starting values.
54 */
55 private final double[][] startingValues;
56 /** The certified values of the parameters. */
57 private final double[] a;
58 /** The certified values of the standard deviation of the parameters. */
59 private final double[] sigA;
60 /** The certified value of the residual sum of squares. */
61 private double residualSumOfSquares;
62 /** The least-squares problem. */
63 private final LeastSquaresProblem problem;
64
65 /**
66 * Creates a new instance of this class from the specified data file. The
67 * file must follow the StRD format.
68 *
69 * @param in the data file
70 * @throws IOException if an I/O error occurs
71 */
72 public StatisticalReferenceDataset(final BufferedReader in)
73 throws IOException {
74
75 final ArrayList<String> lines = new ArrayList<String>();
76 for (String line = in.readLine(); line != null; line = in.readLine()) {
77 lines.add(line);
78 }
79 int[] index = findLineNumbers("Data", lines);
80 if (index == null) {
81 throw new AssertionError("could not find line indices for data");
82 }
83 this.numObservations = index[1] - index[0] + 1;
84 this.x = new double[this.numObservations];
85 this.y = new double[this.numObservations];
86 for (int i = 0; i < this.numObservations; i++) {
87 final String line = lines.get(index[0] + i - 1);
88 final String[] tokens = line.trim().split(" ++");
89 // Data columns are in reverse order!!!
90 this.y[i] = Double.parseDouble(tokens[0]);
91 this.x[i] = Double.parseDouble(tokens[1]);
92 }
93
94 index = findLineNumbers("Starting Values", lines);
95 if (index == null) {
96 throw new AssertionError(
97 "could not find line indices for starting values");
98 }
99 this.numParameters = index[1] - index[0] + 1;
100
101 double[][] start = null;
102 this.a = new double[numParameters];
103 this.sigA = new double[numParameters];
104 for (int i = 0; i < numParameters; i++) {
105 final String line = lines.get(index[0] + i - 1);
106 final String[] tokens = line.trim().split(" ++");
107 if (start == null) {
108 start = new double[tokens.length - 4][numParameters];
109 }
110 for (int j = 2; j < tokens.length - 2; j++) {
111 start[j - 2][i] = Double.parseDouble(tokens[j]);
112 }
113 this.a[i] = Double.parseDouble(tokens[tokens.length - 2]);
114 this.sigA[i] = Double.parseDouble(tokens[tokens.length - 1]);
115 }
116 if (start == null) {
117 throw new IOException("could not find starting values");
118 }
119 this.numStartingPoints = start.length;
120 this.startingValues = start;
121
122 double dummyDouble = Double.NaN;
123 String dummyString = null;
124 for (String line : lines) {
125 if (line.contains("Dataset Name:")) {
126 dummyString = line
127 .substring(line.indexOf("Dataset Name:") + 13,
128 line.indexOf("(")).trim();
129 }
130 if (line.contains("Residual Sum of Squares")) {
131 final String[] tokens = line.split(" ++");
132 dummyDouble = Double.parseDouble(tokens[4].trim());
133 }
134 }
135 if (Double.isNaN(dummyDouble)) {
136 throw new IOException(
137 "could not find certified value of residual sum of squares");
138 }
139 this.residualSumOfSquares = dummyDouble;
140
141 if (dummyString == null) {
142 throw new IOException("could not find dataset name");
143 }
144 this.name = dummyString;
145
146 this.problem = new LeastSquaresProblem();
147 }
148
149 class LeastSquaresProblem {
150 public MultivariateVectorFunction getModelFunction() {
151 return new MultivariateVectorFunction() {
152 public double[] value(final double[] a) {
153 final int n = getNumObservations();
154 final double[] yhat = new double[n];
155 for (int i = 0; i < n; i++) {
156 yhat[i] = getModelValue(getX(i), a);
157 }
158 return yhat;
159 }
160 };
161 }
162
163 public MultivariateMatrixFunction getModelFunctionJacobian() {
164 return new MultivariateMatrixFunction() {
165 public double[][] value(final double[] a)
166 throws IllegalArgumentException {
167 final int n = getNumObservations();
168 final double[][] j = new double[n][];
169 for (int i = 0; i < n; i++) {
170 j[i] = getModelDerivatives(getX(i), a);
171 }
172 return j;
173 }
174 };
175 }
176 }
177
178 /**
179 * Returns the name of this dataset.
180 *
181 * @return the name of the dataset
182 */
183 public String getName() {
184 return name;
185 }
186
187 /**
188 * Returns the total number of observations (data points).
189 *
190 * @return the number of observations
191 */
192 public int getNumObservations() {
193 return numObservations;
194 }
195
196 /**
197 * Returns a copy of the data arrays. The data is laid out as follows <li>
198 * {@code data[0][i] = x[i]},</li> <li>{@code data[1][i] = y[i]},</li>
199 *
200 * @return the array of data points.
201 */
202 public double[][] getData() {
203 return new double[][] { x.clone(), y.clone() };
204 }
205
206 /**
207 * Returns the x-value of the {@code i}-th data point.
208 *
209 * @param i the index of the data point
210 * @return the x-value
211 */
212 public double getX(final int i) {
213 return x[i];
214 }
215
216 /**
217 * Returns the y-value of the {@code i}-th data point.
218 *
219 * @param i the index of the data point
220 * @return the y-value
221 */
222 public double getY(final int i) {
223 return y[i];
224 }
225
226 /**
227 * Returns the total number of parameters.
228 *
229 * @return the number of parameters
230 */
231 public int getNumParameters() {
232 return numParameters;
233 }
234
235 /**
236 * Returns the certified values of the paramters.
237 *
238 * @return the values of the parameters
239 */
240 public double[] getParameters() {
241 return a.clone();
242 }
243
244 /**
245 * Returns the certified value of the {@code i}-th parameter.
246 *
247 * @param i the index of the parameter
248 * @return the value of the parameter
249 */
250 public double getParameter(final int i) {
251 return a[i];
252 }
253
254 /**
255 * Returns the certified values of the standard deviations of the parameters.
256 *
257 * @return the standard deviations of the parameters
258 */
259 public double[] getParametersStandardDeviations() {
260 return sigA.clone();
261 }
262
263 /**
264 * Returns the certified value of the standard deviation of the {@code i}-th
265 * parameter.
266 *
267 * @param i the index of the parameter
268 * @return the standard deviation of the parameter
269 */
270 public double getParameterStandardDeviation(final int i) {
271 return sigA[i];
272 }
273
274 /**
275 * Returns the certified value of the residual sum of squares.
276 *
277 * @return the residual sum of squares
278 */
279 public double getResidualSumOfSquares() {
280 return residualSumOfSquares;
281 }
282
283 /**
284 * Returns the total number of starting points (initial guesses for the
285 * optimization process).
286 *
287 * @return the number of starting points
288 */
289 public int getNumStartingPoints() {
290 return numStartingPoints;
291 }
292
293 /**
294 * Returns the {@code i}-th set of initial values of the parameters.
295 *
296 * @param i the index of the starting point
297 * @return the starting point
298 */
299 public double[] getStartingPoint(final int i) {
300 return startingValues[i].clone();
301 }
302
303 /**
304 * Returns the least-squares problem corresponding to fitting the model to
305 * the specified data.
306 *
307 * @return the least-squares problem
308 */
309 public LeastSquaresProblem getLeastSquaresProblem() {
310 return problem;
311 }
312
313 /**
314 * Returns the value of the model for the specified values of the predictor
315 * variable and the parameters.
316 *
317 * @param x the predictor variable
318 * @param a the parameters
319 * @return the value of the model
320 */
321 public abstract double getModelValue(final double x, final double[] a);
322
323 /**
324 * Returns the values of the partial derivatives of the model with respect
325 * to the parameters.
326 *
327 * @param x the predictor variable
328 * @param a the parameters
329 * @return the partial derivatives
330 */
331 public abstract double[] getModelDerivatives(final double x,
332 final double[] a);
333
334 /**
335 * <p>
336 * Parses the specified text lines, and extracts the indices of the first
337 * and last lines of the data defined by the specified {@code key}. This key
338 * must be one of
339 * </p>
340 * <ul>
341 * <li>{@code "Starting Values"},</li>
342 * <li>{@code "Certified Values"},</li>
343 * <li>{@code "Data"}.</li>
344 * </ul>
345 * <p>
346 * In the NIST data files, the line indices are separated by the keywords
347 * {@code "lines"} and {@code "to"}.
348 * </p>
349 *
350 * @param lines the line of text to be parsed
351 * @return an array of two {@code int}s. First value is the index of the
352 * first line, second value is the index of the last line.
353 * {@code null} if the line could not be parsed.
354 */
355 private static int[] findLineNumbers(final String key,
356 final Iterable<String> lines) {
357 for (String text : lines) {
358 boolean flag = text.contains(key) && text.contains("lines") &&
359 text.contains("to") && text.contains(")");
360 if (flag) {
361 final int[] numbers = new int[2];
362 final String from = text.substring(text.indexOf("lines") + 5,
363 text.indexOf("to"));
364 numbers[0] = Integer.parseInt(from.trim());
365 final String to = text.substring(text.indexOf("to") + 2,
366 text.indexOf(")"));
367 numbers[1] = Integer.parseInt(to.trim());
368 return numbers;
369 }
370 }
371 return null;
372 }
373 }