1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23 package org.hipparchus.analysis;
24
25 import org.hipparchus.analysis.differentiation.DSFactory;
26 import org.hipparchus.analysis.differentiation.Derivative;
27 import org.hipparchus.analysis.differentiation.DerivativeStructure;
28 import org.hipparchus.analysis.differentiation.MultivariateDifferentiableFunction;
29 import org.hipparchus.analysis.differentiation.UnivariateDifferentiableFunction;
30 import org.hipparchus.analysis.function.Identity;
31 import org.hipparchus.exception.LocalizedCoreFormats;
32 import org.hipparchus.exception.MathIllegalArgumentException;
33 import org.hipparchus.util.MathArrays;
34 import org.hipparchus.util.MathUtils;
35
36
37
38
39
40 public class FunctionUtils {
41
42
43
44 private FunctionUtils() {}
45
46
47
48
49
50
51
52
53
54
55 public static UnivariateFunction compose(final UnivariateFunction ... f) {
56 return new UnivariateFunction() {
57
58 @Override
59 public double value(double x) {
60 double r = x;
61 for (int i = f.length - 1; i >= 0; i--) {
62 r = f[i].value(r);
63 }
64 return r;
65 }
66 };
67 }
68
69
70
71
72
73
74
75
76
77
78 public static UnivariateDifferentiableFunction compose(final UnivariateDifferentiableFunction ... f) {
79 return new UnivariateDifferentiableFunction() {
80
81
82 @Override
83 public double value(final double t) {
84 double r = t;
85 for (int i = f.length - 1; i >= 0; i--) {
86 r = f[i].value(r);
87 }
88 return r;
89 }
90
91
92 @Override
93 public <T extends Derivative<T>> T value(final T t) {
94 T r = t;
95 for (int i = f.length - 1; i >= 0; i--) {
96 r = f[i].value(r);
97 }
98 return r;
99 }
100
101 };
102 }
103
104
105
106
107
108
109
110 public static UnivariateFunction add(final UnivariateFunction ... f) {
111 return new UnivariateFunction() {
112
113 @Override
114 public double value(double x) {
115 double r = f[0].value(x);
116 for (int i = 1; i < f.length; i++) {
117 r += f[i].value(x);
118 }
119 return r;
120 }
121 };
122 }
123
124
125
126
127
128
129
130 public static UnivariateDifferentiableFunction add(final UnivariateDifferentiableFunction ... f) {
131 return new UnivariateDifferentiableFunction() {
132
133
134 @Override
135 public double value(final double t) {
136 double r = f[0].value(t);
137 for (int i = 1; i < f.length; i++) {
138 r += f[i].value(t);
139 }
140 return r;
141 }
142
143
144
145
146 @Override
147 public <T extends Derivative<T>> T value(final T t)
148 throws MathIllegalArgumentException {
149 T r = f[0].value(t);
150 for (int i = 1; i < f.length; i++) {
151 r = r.add(f[i].value(t));
152 }
153 return r;
154 }
155
156 };
157 }
158
159
160
161
162
163
164
165 public static UnivariateFunction multiply(final UnivariateFunction ... f) {
166 return new UnivariateFunction() {
167
168 @Override
169 public double value(double x) {
170 double r = f[0].value(x);
171 for (int i = 1; i < f.length; i++) {
172 r *= f[i].value(x);
173 }
174 return r;
175 }
176 };
177 }
178
179
180
181
182
183
184
185 public static UnivariateDifferentiableFunction multiply(final UnivariateDifferentiableFunction ... f) {
186 return new UnivariateDifferentiableFunction() {
187
188
189 @Override
190 public double value(final double t) {
191 double r = f[0].value(t);
192 for (int i = 1; i < f.length; i++) {
193 r *= f[i].value(t);
194 }
195 return r;
196 }
197
198
199 @Override
200 public <T extends Derivative<T>> T value(final T t) {
201 T r = f[0].value(t);
202 for (int i = 1; i < f.length; i++) {
203 r = r.multiply(f[i].value(t));
204 }
205 return r;
206 }
207
208 };
209 }
210
211
212
213
214
215
216
217
218
219
220 public static UnivariateFunction combine(final BivariateFunction combiner,
221 final UnivariateFunction f,
222 final UnivariateFunction g) {
223 return new UnivariateFunction() {
224
225 @Override
226 public double value(double x) {
227 return combiner.value(f.value(x), g.value(x));
228 }
229 };
230 }
231
232
233
234
235
236
237
238
239
240
241
242 public static MultivariateFunction collector(final BivariateFunction combiner,
243 final UnivariateFunction f,
244 final double initialValue) {
245 return new MultivariateFunction() {
246
247 @Override
248 public double value(double[] point) {
249 double result = combiner.value(initialValue, f.value(point[0]));
250 for (int i = 1; i < point.length; i++) {
251 result = combiner.value(result, f.value(point[i]));
252 }
253 return result;
254 }
255 };
256 }
257
258
259
260
261
262
263
264
265
266
267 public static MultivariateFunction collector(final BivariateFunction combiner,
268 final double initialValue) {
269 return collector(combiner, new Identity(), initialValue);
270 }
271
272
273
274
275
276
277
278
279 public static UnivariateFunction fix1stArgument(final BivariateFunction f,
280 final double fixed) {
281 return new UnivariateFunction() {
282
283 @Override
284 public double value(double x) {
285 return f.value(fixed, x);
286 }
287 };
288 }
289
290
291
292
293
294
295
296 public static UnivariateFunction fix2ndArgument(final BivariateFunction f,
297 final double fixed) {
298 return new UnivariateFunction() {
299
300 @Override
301 public double value(double x) {
302 return f.value(x, fixed);
303 }
304 };
305 }
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324 public static double[] sample(UnivariateFunction f, double min, double max, int n)
325 throws MathIllegalArgumentException {
326
327 if (n <= 0) {
328 throw new MathIllegalArgumentException(
329 LocalizedCoreFormats.NOT_POSITIVE_NUMBER_OF_SAMPLES,
330 Integer.valueOf(n));
331 }
332 if (min >= max) {
333 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_LARGE_BOUND_EXCLUDED,
334 min, max);
335 }
336
337 final double[] s = new double[n];
338 final double h = (max - min) / n;
339 for (int i = 0; i < n; i++) {
340 s[i] = f.value(min + i * h);
341 }
342 return s;
343 }
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369 public static UnivariateDifferentiableFunction toDifferentiable(final UnivariateFunction f,
370 final UnivariateFunction ... derivatives) {
371
372 return new UnivariateDifferentiableFunction() {
373
374
375 @Override
376 public double value(final double x) {
377 return f.value(x);
378 }
379
380
381 @Override
382 public <T extends Derivative<T>> T value(final T x) {
383 if (x.getOrder() > derivatives.length) {
384 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_LARGE,
385 x.getOrder(), derivatives.length);
386 }
387 final double[] packed = new double[x.getOrder() + 1];
388 packed[0] = f.value(x.getValue());
389 for (int i = 0; i < x.getOrder(); ++i) {
390 packed[i + 1] = derivatives[i].value(x.getValue());
391 }
392 return x.compose(packed);
393 }
394
395 };
396
397 }
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423 public static MultivariateDifferentiableFunction toDifferentiable(final MultivariateFunction f,
424 final MultivariateVectorFunction gradient) {
425
426 return new MultivariateDifferentiableFunction() {
427
428
429 @Override
430 public double value(final double[] point) {
431 return f.value(point);
432 }
433
434
435 @Override
436 public DerivativeStructure value(final DerivativeStructure[] point) {
437
438
439 final double[] dPoint = new double[point.length];
440 for (int i = 0; i < point.length; ++i) {
441 dPoint[i] = point[i].getValue();
442 if (point[i].getOrder() > 1) {
443 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_LARGE,
444 point[i].getOrder(), 1);
445 }
446 }
447
448
449 final double v = f.value(dPoint);
450 final double[] dv = gradient.value(dPoint);
451 MathUtils.checkDimension(dv.length, point.length);
452
453
454 final int parameters = point[0].getFreeParameters();
455 final double[] partials = new double[point.length];
456 final double[] packed = new double[parameters + 1];
457 packed[0] = v;
458 final int[] orders = new int[parameters];
459 for (int i = 0; i < parameters; ++i) {
460
461
462 orders[i] = 1;
463 for (int j = 0; j < point.length; ++j) {
464 partials[j] = point[j].getPartialDerivative(orders);
465 }
466 orders[i] = 0;
467
468
469 packed[i + 1] = MathArrays.linearCombination(dv, partials);
470
471 }
472
473 return point[0].getFactory().build(packed);
474
475 }
476
477 };
478
479 }
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494 public static UnivariateFunction derivative(final UnivariateDifferentiableFunction f, final int order) {
495
496 final DSFactory factory = new DSFactory(1, order);
497
498 return new UnivariateFunction() {
499
500
501 @Override
502 public double value(final double x) {
503 final DerivativeStructure dsX = factory.variable(0, x);
504 return f.value(dsX).getPartialDerivative(order);
505 }
506
507 };
508 }
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523 public static MultivariateFunction derivative(final MultivariateDifferentiableFunction f, final int[] orders) {
524
525
526 int sum = 0;
527 for (final int order : orders) {
528 sum += order;
529 }
530 final int sumOrders = sum;
531
532 return new MultivariateFunction() {
533
534
535 private DSFactory factory;
536
537
538 @Override
539 public double value(final double[] point) {
540
541 if (factory == null || point.length != factory.getCompiler().getFreeParameters()) {
542
543 factory = new DSFactory(point.length, sumOrders);
544 }
545
546
547 final DerivativeStructure[] dsPoint = new DerivativeStructure[point.length];
548 for (int i = 0; i < point.length; ++i) {
549 dsPoint[i] = factory.variable(i, point[i]);
550 }
551
552 return f.value(dsPoint).getPartialDerivative(orders);
553
554 }
555
556 };
557 }
558
559 }