1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23 package org.hipparchus.analysis;
24
25 import org.hipparchus.analysis.differentiation.DSFactory;
26 import org.hipparchus.analysis.differentiation.Derivative;
27 import org.hipparchus.analysis.differentiation.DerivativeStructure;
28 import org.hipparchus.analysis.differentiation.MultivariateDifferentiableFunction;
29 import org.hipparchus.analysis.differentiation.UnivariateDifferentiableFunction;
30 import org.hipparchus.analysis.function.Add;
31 import org.hipparchus.analysis.function.Constant;
32 import org.hipparchus.analysis.function.Cos;
33 import org.hipparchus.analysis.function.Cosh;
34 import org.hipparchus.analysis.function.Divide;
35 import org.hipparchus.analysis.function.Identity;
36 import org.hipparchus.analysis.function.Inverse;
37 import org.hipparchus.analysis.function.Log;
38 import org.hipparchus.analysis.function.Max;
39 import org.hipparchus.analysis.function.Min;
40 import org.hipparchus.analysis.function.Minus;
41 import org.hipparchus.analysis.function.Multiply;
42 import org.hipparchus.analysis.function.Pow;
43 import org.hipparchus.analysis.function.Power;
44 import org.hipparchus.analysis.function.Sin;
45 import org.hipparchus.analysis.function.Sinc;
46 import org.hipparchus.analysis.function.Subtract;
47 import org.hipparchus.exception.LocalizedCoreFormats;
48 import org.hipparchus.exception.MathIllegalArgumentException;
49 import org.hipparchus.util.FastMath;
50 import org.junit.Assert;
51 import org.junit.Test;
52
53
54
55
56 public class FunctionUtilsTest {
57 private final double EPS = FastMath.ulp(1d);
58
59 @Test
60 public void testCompose() {
61 UnivariateFunction id = new Identity();
62 Assert.assertEquals(3, FunctionUtils.compose(id, id, id).value(3), EPS);
63
64 UnivariateFunction c = new Constant(4);
65 Assert.assertEquals(4, FunctionUtils.compose(id, c).value(3), EPS);
66 Assert.assertEquals(4, FunctionUtils.compose(c, id).value(3), EPS);
67
68 UnivariateFunction m = new Minus();
69 Assert.assertEquals(-3, FunctionUtils.compose(m).value(3), EPS);
70 Assert.assertEquals(3, FunctionUtils.compose(m, m).value(3), EPS);
71
72 UnivariateFunction inv = new Inverse();
73 Assert.assertEquals(-0.25, FunctionUtils.compose(inv, m, c, id).value(3), EPS);
74
75 UnivariateFunction pow = new Power(2);
76 Assert.assertEquals(81, FunctionUtils.compose(pow, pow).value(3), EPS);
77 }
78
79 @Test
80 public void testComposeDifferentiable() {
81 DSFactory factory = new DSFactory(1, 1);
82 UnivariateDifferentiableFunction id = new Identity();
83 Assert.assertEquals(1, FunctionUtils.compose(id, id, id).value(factory.variable(0, 3)).getPartialDerivative(1), EPS);
84 Assert.assertEquals(1.5, FunctionUtils.compose(id, id, id).value(1.5), EPS);
85
86 UnivariateDifferentiableFunction c = new Constant(4);
87 Assert.assertEquals(0, FunctionUtils.compose(id, c).value(factory.variable(0, 3)).getPartialDerivative(1), EPS);
88 Assert.assertEquals(0, FunctionUtils.compose(c, id).value(factory.variable(0, 3)).getPartialDerivative(1), EPS);
89
90 UnivariateDifferentiableFunction m = new Minus();
91 Assert.assertEquals(-1, FunctionUtils.compose(m).value(factory.variable(0, 3)).getPartialDerivative(1), EPS);
92 Assert.assertEquals(1, FunctionUtils.compose(m, m).value(factory.variable(0, 3)).getPartialDerivative(1), EPS);
93
94 UnivariateDifferentiableFunction inv = new Inverse();
95 Assert.assertEquals(0.25, FunctionUtils.compose(inv, m, id).value(factory.variable(0, 2)).getPartialDerivative(1), EPS);
96
97 UnivariateDifferentiableFunction pow = new Power(2);
98 Assert.assertEquals(108, FunctionUtils.compose(pow, pow).value(factory.variable(0, 3)).getPartialDerivative(1), EPS);
99
100 UnivariateDifferentiableFunction log = new Log();
101 double a = 9876.54321;
102 Assert.assertEquals(pow.value(factory.variable(0, a)).getPartialDerivative(1) / pow.value(a),
103 FunctionUtils.compose(log, pow).value(factory.variable(0, a)).getPartialDerivative(1), EPS);
104 }
105
106 @Test
107 public void testAdd() {
108 UnivariateFunction id = new Identity();
109 UnivariateFunction c = new Constant(4);
110 UnivariateFunction m = new Minus();
111 UnivariateFunction inv = new Inverse();
112
113 Assert.assertEquals(4.5, FunctionUtils.add(inv, m, c, id).value(2), EPS);
114 Assert.assertEquals(4 + 2, FunctionUtils.add(c, id).value(2), EPS);
115 Assert.assertEquals(4 - 2, FunctionUtils.add(c, FunctionUtils.compose(m, id)).value(2), EPS);
116 }
117
118 @Test
119 public void testAddDifferentiable() {
120 UnivariateDifferentiableFunction sin = new Sin();
121 UnivariateDifferentiableFunction c = new Constant(4);
122 UnivariateDifferentiableFunction m = new Minus();
123 UnivariateDifferentiableFunction inv = new Inverse();
124
125 final double a = 123.456;
126 DSFactory factory = new DSFactory(1, 1);
127 Assert.assertEquals(- 1 / (a * a) -1 + FastMath.cos(a),
128 FunctionUtils.add(inv, m, c, sin).value(factory.variable(0, a)).getPartialDerivative(1),
129 EPS);
130 Assert.assertEquals(4 + FastMath.sin(1.2), FunctionUtils.add(sin, c).value(1.2), EPS);
131 }
132
133 @Test
134 public void testMultiply() {
135 UnivariateFunction c = new Constant(4);
136 Assert.assertEquals(16, FunctionUtils.multiply(c, c).value(12345), EPS);
137
138 UnivariateFunction inv = new Inverse();
139 UnivariateFunction pow = new Power(2);
140 Assert.assertEquals(1, FunctionUtils.multiply(FunctionUtils.compose(inv, pow), pow).value(3.5), EPS);
141 }
142
143 @Test
144 public void testMultiplyDifferentiable() {
145 UnivariateDifferentiableFunction c = new Constant(4);
146 UnivariateDifferentiableFunction id = new Identity();
147 DSFactory factory = new DSFactory(1, 1);
148 final double a = 1.2345678;
149 Assert.assertEquals(8 * a, FunctionUtils.multiply(c, id, id).value(factory.variable(0, a)).getPartialDerivative(1), EPS);
150
151 UnivariateDifferentiableFunction inv = new Inverse();
152 UnivariateDifferentiableFunction pow = new Power(2.5);
153 UnivariateDifferentiableFunction cos = new Cos();
154 Assert.assertEquals(1.5 * FastMath.sqrt(a) * FastMath.cos(a) - FastMath.pow(a, 1.5) * FastMath.sin(a),
155 FunctionUtils.multiply(inv, pow, cos).value(factory.variable(0, a)).getPartialDerivative(1), EPS);
156
157 UnivariateDifferentiableFunction cosh = new Cosh();
158 Assert.assertEquals(1.5 * FastMath.sqrt(a) * FastMath.cosh(a) + FastMath.pow(a, 1.5) * FastMath.sinh(a),
159 FunctionUtils.multiply(inv, pow, cosh).value(factory.variable(0, a)).getPartialDerivative(1), 8 * EPS);
160 Assert.assertEquals(16, FunctionUtils.multiply(c, c).value(FastMath.PI), EPS);
161 }
162
163 @Test
164 public void testCombine() {
165 BivariateFunction bi = new Subtract();
166 UnivariateFunction id = new Identity();
167 UnivariateFunction m = new Minus();
168 UnivariateFunction c = FunctionUtils.combine(bi, id, m);
169 Assert.assertEquals(4.6912, c.value(2.3456), EPS);
170
171 bi = new Multiply();
172 UnivariateFunction inv = new Inverse();
173 c = FunctionUtils.combine(bi, id, inv);
174 Assert.assertEquals(1, c.value(2.3456), EPS);
175 }
176
177 @Test
178 public void testCollector() {
179 BivariateFunction bi = new Add();
180 MultivariateFunction coll = FunctionUtils.collector(bi, 0);
181 Assert.assertEquals(10, coll.value(new double[] {1, 2, 3, 4}), EPS);
182
183 bi = new Multiply();
184 coll = FunctionUtils.collector(bi, 1);
185 Assert.assertEquals(24, coll.value(new double[] {1, 2, 3, 4}), EPS);
186
187 bi = new Max();
188 coll = FunctionUtils.collector(bi, Double.NEGATIVE_INFINITY);
189 Assert.assertEquals(10, coll.value(new double[] {1, -2, 7.5, 10, -24, 9.99}), 0);
190
191 bi = new Min();
192 coll = FunctionUtils.collector(bi, Double.POSITIVE_INFINITY);
193 Assert.assertEquals(-24, coll.value(new double[] {1, -2, 7.5, 10, -24, 9.99}), 0);
194 }
195
196 @Test
197 public void testSinc() {
198 BivariateFunction div = new Divide();
199 UnivariateFunction sin = new Sin();
200 UnivariateFunction id = new Identity();
201 UnivariateFunction sinc1 = FunctionUtils.combine(div, sin, id);
202 UnivariateFunction sinc2 = new Sinc();
203
204 for (int i = 0; i < 10; i++) {
205 double x = FastMath.random();
206 Assert.assertEquals(sinc1.value(x), sinc2.value(x), EPS);
207 }
208 }
209
210 @Test
211 public void testFixingArguments() {
212 UnivariateFunction scaler = FunctionUtils.fix1stArgument(new Multiply(), 10);
213 Assert.assertEquals(1.23456, scaler.value(0.123456), EPS);
214
215 UnivariateFunction pow1 = new Power(2);
216 UnivariateFunction pow2 = FunctionUtils.fix2ndArgument(new Pow(), 2);
217
218 for (int i = 0; i < 10; i++) {
219 double x = FastMath.random() * 10;
220 Assert.assertEquals(pow1.value(x), pow2.value(x), 0);
221 }
222 }
223
224 @Test(expected = MathIllegalArgumentException.class)
225 public void testSampleWrongBounds(){
226 FunctionUtils.sample(new Sin(), FastMath.PI, 0.0, 10);
227 }
228
229 @Test(expected = MathIllegalArgumentException.class)
230 public void testSampleNegativeNumberOfPoints(){
231 FunctionUtils.sample(new Sin(), 0.0, FastMath.PI, -1);
232 }
233
234 @Test(expected = MathIllegalArgumentException.class)
235 public void testSampleNullNumberOfPoints(){
236 FunctionUtils.sample(new Sin(), 0.0, FastMath.PI, 0);
237 }
238
239 @Test
240 public void testSample() {
241 final int n = 11;
242 final double min = 0.0;
243 final double max = FastMath.PI;
244 final double[] actual = FunctionUtils.sample(new Sin(), min, max, n);
245 for (int i = 0; i < n; i++) {
246 final double x = min + (max - min) / n * i;
247 Assert.assertEquals("x = " + x, FastMath.sin(x), actual[i], 0.0);
248 }
249 }
250
251 @Test
252 public void testToDifferentiableUnivariate() {
253
254 final UnivariateFunction f0 = new UnivariateFunction() {
255 @Override
256 public double value(final double x) {
257 return x * x;
258 }
259 };
260 final UnivariateFunction f1 = new UnivariateFunction() {
261 @Override
262 public double value(final double x) {
263 return 2 * x;
264 }
265 };
266 final UnivariateFunction f2 = new UnivariateFunction() {
267 @Override
268 public double value(final double x) {
269 return 2;
270 }
271 };
272 final UnivariateDifferentiableFunction f = FunctionUtils.toDifferentiable(f0, f1, f2);
273
274 DSFactory factory = new DSFactory(1, 2);
275 for (double t = -1.0; t < 1; t += 0.01) {
276
277 DerivativeStructure dsT = factory.variable(0, t);
278 DerivativeStructure y = f.value(dsT.sin());
279 Assert.assertEquals(FastMath.sin(t) * FastMath.sin(t), f.value(FastMath.sin(t)), 1.0e-15);
280 Assert.assertEquals(FastMath.sin(t) * FastMath.sin(t), y.getValue(), 1.0e-15);
281 Assert.assertEquals(2 * FastMath.cos(t) * FastMath.sin(t), y.getPartialDerivative(1), 1.0e-15);
282 Assert.assertEquals(2 * (1 - 2 * FastMath.sin(t) * FastMath.sin(t)), y.getPartialDerivative(2), 1.0e-15);
283 }
284
285 try {
286 f.value(new DSFactory(1, 3).constant(0.0));
287 Assert.fail("an exception should have been thrown");
288 } catch (MathIllegalArgumentException e) {
289 Assert.assertEquals(LocalizedCoreFormats.NUMBER_TOO_LARGE, e.getSpecifier());
290 Assert.assertEquals(2, ((Integer) e.getParts()[1]).intValue());
291 Assert.assertEquals(3, ((Integer) e.getParts()[0]).intValue());
292 }
293 }
294
295 @Test
296 public void testToDifferentiableMultivariate() {
297
298 final double a = 1.5;
299 final double b = 0.5;
300 final MultivariateFunction f = new MultivariateFunction() {
301 @Override
302 public double value(final double[] point) {
303 return a * point[0] + b * point[1];
304 }
305 };
306 final MultivariateVectorFunction gradient = new MultivariateVectorFunction() {
307 @Override
308 public double[] value(final double[] point) {
309 return new double[] { a, b };
310 }
311 };
312 final MultivariateDifferentiableFunction mdf = FunctionUtils.toDifferentiable(f, gradient);
313
314 DSFactory factory11 = new DSFactory(1, 1);
315 for (double t = -1.0; t < 1; t += 0.01) {
316
317 DerivativeStructure dsT = factory11.variable(0, t);
318 DerivativeStructure y = mdf.value(new DerivativeStructure[] { dsT.sin(), dsT.cos() });
319 Assert.assertEquals(a * FastMath.sin(t) + b * FastMath.cos(t), y.getValue(), 1.0e-15);
320 Assert.assertEquals(a * FastMath.cos(t) - b * FastMath.sin(t), y.getPartialDerivative(1), 1.0e-15);
321 }
322
323 DSFactory factory21 = new DSFactory(2, 1);
324 for (double u = -1.0; u < 1; u += 0.01) {
325 DerivativeStructure dsU = factory21.variable(0, u);
326 for (double v = -1.0; v < 1; v += 0.01) {
327 DerivativeStructure dsV = factory21.variable(1, v);
328 DerivativeStructure y = mdf.value(new DerivativeStructure[] { dsU, dsV });
329 Assert.assertEquals(a * u + b * v, mdf.value(new double[] { u, v }), 1.0e-15);
330 Assert.assertEquals(a * u + b * v, y.getValue(), 1.0e-15);
331 Assert.assertEquals(a, y.getPartialDerivative(1, 0), 1.0e-15);
332 Assert.assertEquals(b, y.getPartialDerivative(0, 1), 1.0e-15);
333 }
334 }
335
336 DSFactory factory13 = new DSFactory(1, 3);
337 try {
338 mdf.value(new DerivativeStructure[] { factory13.constant(0.0), factory13.constant(0.0) });
339 Assert.fail("an exception should have been thrown");
340 } catch (MathIllegalArgumentException e) {
341 Assert.assertEquals(LocalizedCoreFormats.NUMBER_TOO_LARGE, e.getSpecifier());
342 Assert.assertEquals(1, ((Integer) e.getParts()[1]).intValue());
343 Assert.assertEquals(3, ((Integer) e.getParts()[0]).intValue());
344 }
345 }
346
347 @Test
348 public void testToDifferentiableMultivariateInconsistentGradient() {
349
350 final double a = 1.5;
351 final double b = 0.5;
352 final MultivariateFunction f = new MultivariateFunction() {
353 @Override
354 public double value(final double[] point) {
355 return a * point[0] + b * point[1];
356 }
357 };
358 final MultivariateVectorFunction gradient = new MultivariateVectorFunction() {
359 @Override
360 public double[] value(final double[] point) {
361 return new double[] { a, b, 0.0 };
362 }
363 };
364 final MultivariateDifferentiableFunction mdf = FunctionUtils.toDifferentiable(f, gradient);
365
366 DSFactory factory = new DSFactory(1, 1);
367 try {
368 DerivativeStructure dsT = factory.variable(0, 0.0);
369 mdf.value(new DerivativeStructure[] { dsT.sin(), dsT.cos() });
370 Assert.fail("an exception should have been thrown");
371 } catch (MathIllegalArgumentException e) {
372 Assert.assertEquals(3, ((Integer) e.getParts()[0]).intValue());
373 Assert.assertEquals(2, ((Integer) e.getParts()[1]).intValue());
374 }
375 }
376
377 @Test
378 public void testDerivativeUnivariate() {
379
380 final UnivariateDifferentiableFunction f = new UnivariateDifferentiableFunction() {
381
382 @Override
383 public double value(double x) {
384 return x * x;
385 }
386
387 @Override
388 public <T extends Derivative<T>> T value(T x) {
389 return x.square();
390 }
391
392 };
393
394 final UnivariateFunction f0 = FunctionUtils.derivative(f, 0);
395 final UnivariateFunction f1 = FunctionUtils.derivative(f, 1);
396 final UnivariateFunction f2 = FunctionUtils.derivative(f, 2);
397
398 for (double t = -1.0; t < 1; t += 0.01) {
399 Assert.assertEquals(t * t, f0.value(t), 1.0e-15);
400 Assert.assertEquals(2 * t, f1.value(t), 1.0e-15);
401 Assert.assertEquals(2, f2.value(t), 1.0e-15);
402 }
403
404 }
405
406 @Test
407 public void testDerivativeMultivariate() {
408
409 final double a = 1.5;
410 final double b = 0.5;
411 final double c = 0.25;
412 final MultivariateDifferentiableFunction mdf = new MultivariateDifferentiableFunction() {
413
414 @Override
415 public double value(double[] point) {
416 return a * point[0] * point[0] + b * point[1] * point[1] + c * point[0] * point[1];
417 }
418
419 @Override
420 public DerivativeStructure value(DerivativeStructure[] point) {
421 DerivativeStructure x = point[0];
422 DerivativeStructure y = point[1];
423 DerivativeStructure x2 = x.square();
424 DerivativeStructure y2 = y.square();
425 DerivativeStructure xy = x.multiply(y);
426 return x2.multiply(a).add(y2.multiply(b)).add(xy.multiply(c));
427 }
428
429 };
430
431 final MultivariateFunction f = FunctionUtils.derivative(mdf, new int[] { 0, 0 });
432 final MultivariateFunction dfdx = FunctionUtils.derivative(mdf, new int[] { 1, 0 });
433 final MultivariateFunction dfdy = FunctionUtils.derivative(mdf, new int[] { 0, 1 });
434 final MultivariateFunction d2fdx2 = FunctionUtils.derivative(mdf, new int[] { 2, 0 });
435 final MultivariateFunction d2fdy2 = FunctionUtils.derivative(mdf, new int[] { 0, 2 });
436 final MultivariateFunction d2fdxdy = FunctionUtils.derivative(mdf, new int[] { 1, 1 });
437
438 for (double x = -1.0; x < 1; x += 0.01) {
439 for (double y = -1.0; y < 1; y += 0.01) {
440 Assert.assertEquals(a * x * x + b * y * y + c * x * y, f.value(new double[] { x, y }), 1.0e-15);
441 Assert.assertEquals(2 * a * x + c * y, dfdx.value(new double[] { x, y }), 1.0e-15);
442 Assert.assertEquals(2 * b * y + c * x, dfdy.value(new double[] { x, y }), 1.0e-15);
443 Assert.assertEquals(2 * a, d2fdx2.value(new double[] { x, y }), 1.0e-15);
444 Assert.assertEquals(2 * b, d2fdy2.value(new double[] { x, y }), 1.0e-15);
445 Assert.assertEquals(c, d2fdxdy.value(new double[] { x, y }), 1.0e-15);
446 }
447 }
448
449 }
450
451 }