FunctionUtils.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      https://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. /*
  18.  * This is not the original file distributed by the Apache Software Foundation
  19.  * It has been modified by the Hipparchus project
  20.  */

  21. package org.hipparchus.analysis;

  22. import org.hipparchus.analysis.differentiation.DSFactory;
  23. import org.hipparchus.analysis.differentiation.Derivative;
  24. import org.hipparchus.analysis.differentiation.DerivativeStructure;
  25. import org.hipparchus.analysis.differentiation.MultivariateDifferentiableFunction;
  26. import org.hipparchus.analysis.differentiation.UnivariateDifferentiableFunction;
  27. import org.hipparchus.analysis.function.Identity;
  28. import org.hipparchus.exception.LocalizedCoreFormats;
  29. import org.hipparchus.exception.MathIllegalArgumentException;
  30. import org.hipparchus.util.MathArrays;
  31. import org.hipparchus.util.MathUtils;

  32. /**
  33.  * Utilities for manipulating function objects.
  34.  *
  35.  */
  36. public class FunctionUtils {
  37.     /**
  38.      * Class only contains static methods.
  39.      */
  40.     private FunctionUtils() {}

  41.     /**
  42.      * Composes functions.
  43.      * <p>
  44.      * The functions in the argument list are composed sequentially, in the
  45.      * given order.  For example, compose(f1,f2,f3) acts like f1(f2(f3(x))).</p>
  46.      *
  47.      * @param f List of functions.
  48.      * @return the composite function.
  49.      */
  50.     public static UnivariateFunction compose(final UnivariateFunction ... f) {
  51.         return new UnivariateFunction() {
  52.             /** {@inheritDoc} */
  53.             @Override
  54.             public double value(double x) {
  55.                 double r = x;
  56.                 for (int i = f.length - 1; i >= 0; i--) {
  57.                     r = f[i].value(r);
  58.                 }
  59.                 return r;
  60.             }
  61.         };
  62.     }

  63.     /**
  64.      * Composes functions.
  65.      * <p>
  66.      * The functions in the argument list are composed sequentially, in the
  67.      * given order.  For example, compose(f1,f2,f3) acts like f1(f2(f3(x))).</p>
  68.      *
  69.      * @param f List of functions.
  70.      * @return the composite function.
  71.      */
  72.     public static UnivariateDifferentiableFunction compose(final UnivariateDifferentiableFunction ... f) {
  73.         return new UnivariateDifferentiableFunction() {

  74.             /** {@inheritDoc} */
  75.             @Override
  76.             public double value(final double t) {
  77.                 double r = t;
  78.                 for (int i = f.length - 1; i >= 0; i--) {
  79.                     r = f[i].value(r);
  80.                 }
  81.                 return r;
  82.             }

  83.             /** {@inheritDoc} */
  84.             @Override
  85.             public <T extends Derivative<T>> T value(final T t) {
  86.                 T r = t;
  87.                 for (int i = f.length - 1; i >= 0; i--) {
  88.                     r = f[i].value(r);
  89.                 }
  90.                 return r;
  91.             }

  92.         };
  93.     }

  94.     /**
  95.      * Adds functions.
  96.      *
  97.      * @param f List of functions.
  98.      * @return a function that computes the sum of the functions.
  99.      */
  100.     public static UnivariateFunction add(final UnivariateFunction ... f) {
  101.         return new UnivariateFunction() {
  102.             /** {@inheritDoc} */
  103.             @Override
  104.             public double value(double x) {
  105.                 double r = f[0].value(x);
  106.                 for (int i = 1; i < f.length; i++) {
  107.                     r += f[i].value(x);
  108.                 }
  109.                 return r;
  110.             }
  111.         };
  112.     }

  113.     /**
  114.      * Adds functions.
  115.      *
  116.      * @param f List of functions.
  117.      * @return a function that computes the sum of the functions.
  118.      */
  119.     public static UnivariateDifferentiableFunction add(final UnivariateDifferentiableFunction ... f) {
  120.         return new UnivariateDifferentiableFunction() {

  121.             /** {@inheritDoc} */
  122.             @Override
  123.             public double value(final double t) {
  124.                 double r = f[0].value(t);
  125.                 for (int i = 1; i < f.length; i++) {
  126.                     r += f[i].value(t);
  127.                 }
  128.                 return r;
  129.             }

  130.             /** {@inheritDoc}
  131.              * @throws MathIllegalArgumentException if functions are not consistent with each other
  132.              */
  133.             @Override
  134.             public <T extends Derivative<T>> T value(final T t)
  135.                 throws MathIllegalArgumentException {
  136.                 T r = f[0].value(t);
  137.                 for (int i = 1; i < f.length; i++) {
  138.                     r = r.add(f[i].value(t));
  139.                 }
  140.                 return r;
  141.             }

  142.         };
  143.     }

  144.     /**
  145.      * Multiplies functions.
  146.      *
  147.      * @param f List of functions.
  148.      * @return a function that computes the product of the functions.
  149.      */
  150.     public static UnivariateFunction multiply(final UnivariateFunction ... f) {
  151.         return new UnivariateFunction() {
  152.             /** {@inheritDoc} */
  153.             @Override
  154.             public double value(double x) {
  155.                 double r = f[0].value(x);
  156.                 for (int i = 1; i < f.length; i++) {
  157.                     r *= f[i].value(x);
  158.                 }
  159.                 return r;
  160.             }
  161.         };
  162.     }

  163.     /**
  164.      * Multiplies functions.
  165.      *
  166.      * @param f List of functions.
  167.      * @return a function that computes the product of the functions.
  168.      */
  169.     public static UnivariateDifferentiableFunction multiply(final UnivariateDifferentiableFunction ... f) {
  170.         return new UnivariateDifferentiableFunction() {

  171.             /** {@inheritDoc} */
  172.             @Override
  173.             public double value(final double t) {
  174.                 double r = f[0].value(t);
  175.                 for (int i = 1; i < f.length; i++) {
  176.                     r  *= f[i].value(t);
  177.                 }
  178.                 return r;
  179.             }

  180.             /** {@inheritDoc} */
  181.             @Override
  182.             public <T extends Derivative<T>> T value(final T t) {
  183.                 T r = f[0].value(t);
  184.                 for (int i = 1; i < f.length; i++) {
  185.                     r = r.multiply(f[i].value(t));
  186.                 }
  187.                 return r;
  188.             }

  189.         };
  190.     }

  191.     /**
  192.      * Returns the univariate function
  193.      * {@code h(x) = combiner(f(x), g(x)).}
  194.      *
  195.      * @param combiner Combiner function.
  196.      * @param f Function.
  197.      * @param g Function.
  198.      * @return the composite function.
  199.      */
  200.     public static UnivariateFunction combine(final BivariateFunction combiner,
  201.                                              final UnivariateFunction f,
  202.                                              final UnivariateFunction g) {
  203.         return new UnivariateFunction() {
  204.             /** {@inheritDoc} */
  205.             @Override
  206.             public double value(double x) {
  207.                 return combiner.value(f.value(x), g.value(x));
  208.             }
  209.         };
  210.     }

  211.     /**
  212.      * Returns a MultivariateFunction h(x[]) defined by <pre> <code>
  213.      * h(x[]) = combiner(...combiner(combiner(initialValue,f(x[0])),f(x[1]))...),f(x[x.length-1]))
  214.      * </code></pre>
  215.      *
  216.      * @param combiner Combiner function.
  217.      * @param f Function.
  218.      * @param initialValue Initial value.
  219.      * @return a collector function.
  220.      */
  221.     public static MultivariateFunction collector(final BivariateFunction combiner,
  222.                                                  final UnivariateFunction f,
  223.                                                  final double initialValue) {
  224.         return new MultivariateFunction() {
  225.             /** {@inheritDoc} */
  226.             @Override
  227.             public double value(double[] point) {
  228.                 double result = combiner.value(initialValue, f.value(point[0]));
  229.                 for (int i = 1; i < point.length; i++) {
  230.                     result = combiner.value(result, f.value(point[i]));
  231.                 }
  232.                 return result;
  233.             }
  234.         };
  235.     }

  236.     /**
  237.      * Returns a MultivariateFunction h(x[]) defined by <pre> <code>
  238.      * h(x[]) = combiner(...combiner(combiner(initialValue,x[0]),x[1])...),x[x.length-1])
  239.      * </code></pre>
  240.      *
  241.      * @param combiner Combiner function.
  242.      * @param initialValue Initial value.
  243.      * @return a collector function.
  244.      */
  245.     public static MultivariateFunction collector(final BivariateFunction combiner,
  246.                                                  final double initialValue) {
  247.         return collector(combiner, new Identity(), initialValue);
  248.     }

  249.     /**
  250.      * Creates a unary function by fixing the first argument of a binary function.
  251.      *
  252.      * @param f Binary function.
  253.      * @param fixed value to which the first argument of {@code f} is set.
  254.      * @return the unary function h(x) = f(fixed, x)
  255.      */
  256.     public static UnivariateFunction fix1stArgument(final BivariateFunction f,
  257.                                                     final double fixed) {
  258.         return new UnivariateFunction() {
  259.             /** {@inheritDoc} */
  260.             @Override
  261.             public double value(double x) {
  262.                 return f.value(fixed, x);
  263.             }
  264.         };
  265.     }
  266.     /**
  267.      * Creates a unary function by fixing the second argument of a binary function.
  268.      *
  269.      * @param f Binary function.
  270.      * @param fixed value to which the second argument of {@code f} is set.
  271.      * @return the unary function h(x) = f(x, fixed)
  272.      */
  273.     public static UnivariateFunction fix2ndArgument(final BivariateFunction f,
  274.                                                     final double fixed) {
  275.         return new UnivariateFunction() {
  276.             /** {@inheritDoc} */
  277.             @Override
  278.             public double value(double x) {
  279.                 return f.value(x, fixed);
  280.             }
  281.         };
  282.     }

  283.     /**
  284.      * Samples the specified univariate real function on the specified interval.
  285.      * <p>
  286.      * The interval is divided equally into {@code n} sections and sample points
  287.      * are taken from {@code min} to {@code max - (max - min) / n}; therefore
  288.      * {@code f} is not sampled at the upper bound {@code max}.</p>
  289.      *
  290.      * @param f Function to be sampled
  291.      * @param min Lower bound of the interval (included).
  292.      * @param max Upper bound of the interval (excluded).
  293.      * @param n Number of sample points.
  294.      * @return the array of samples.
  295.      * @throws MathIllegalArgumentException if the lower bound {@code min} is
  296.      * greater than, or equal to the upper bound {@code max}.
  297.      * @throws MathIllegalArgumentException if the number of sample points
  298.      * {@code n} is negative.
  299.      */
  300.     public static double[] sample(UnivariateFunction f, double min, double max, int n)
  301.        throws MathIllegalArgumentException {

  302.         if (n <= 0) {
  303.             throw new MathIllegalArgumentException(
  304.                     LocalizedCoreFormats.NOT_POSITIVE_NUMBER_OF_SAMPLES,
  305.                     n);
  306.         }
  307.         if (min >= max) {
  308.             throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_LARGE_BOUND_EXCLUDED,
  309.                                                    min, max);
  310.         }

  311.         final double[] s = new double[n];
  312.         final double h = (max - min) / n;
  313.         for (int i = 0; i < n; i++) {
  314.             s[i] = f.value(min + i * h);
  315.         }
  316.         return s;
  317.     }

  318.     /** Convert regular functions to {@link UnivariateDifferentiableFunction}.
  319.      * <p>
  320.      * This method handle the case with one free parameter and several derivatives.
  321.      * For the case with several free parameters and only first order derivatives,
  322.      * see {@link #toDifferentiable(MultivariateFunction, MultivariateVectorFunction)}.
  323.      * There are no direct support for intermediate cases, with several free parameters
  324.      * and order 2 or more derivatives, as is would be difficult to specify all the
  325.      * cross derivatives.
  326.      * </p>
  327.      * <p>
  328.      * Note that the derivatives are expected to be computed only with respect to the
  329.      * raw parameter x of the base function, i.e. they are df/dx, df<sup>2</sup>/dx<sup>2</sup>, ...
  330.      * Even if the built function is later used in a composition like f(sin(t)), the provided
  331.      * derivatives should <em>not</em> apply the composition with sine and its derivatives by
  332.      * themselves. The composition will be done automatically here and the result will properly
  333.      * contain f(sin(t)), df(sin(t))/dt, df<sup>2</sup>(sin(t))/dt<sup>2</sup> despite the
  334.      * provided derivatives functions know nothing about the sine function.
  335.      * </p>
  336.      * @param f base function f(x)
  337.      * @param derivatives derivatives of the base function, in increasing differentiation order
  338.      * @return a differentiable function with value and all specified derivatives
  339.      * @see #toDifferentiable(MultivariateFunction, MultivariateVectorFunction)
  340.      * @see #derivative(UnivariateDifferentiableFunction, int)
  341.      */
  342.     public static UnivariateDifferentiableFunction toDifferentiable(final UnivariateFunction f,
  343.                                                                     final UnivariateFunction ... derivatives) {

  344.         return new UnivariateDifferentiableFunction() {

  345.             /** {@inheritDoc} */
  346.             @Override
  347.             public double value(final double x) {
  348.                 return f.value(x);
  349.             }

  350.             /** {@inheritDoc} */
  351.             @Override
  352.             public <T extends Derivative<T>> T value(final T x) {
  353.                 if (x.getOrder() > derivatives.length) {
  354.                     throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_LARGE,
  355.                                                            x.getOrder(), derivatives.length);
  356.                 }
  357.                 final double[] packed = new double[x.getOrder() + 1];
  358.                 packed[0] = f.value(x.getValue());
  359.                 for (int i = 0; i < x.getOrder(); ++i) {
  360.                     packed[i + 1] = derivatives[i].value(x.getValue());
  361.                 }
  362.                 return x.compose(packed);
  363.             }

  364.         };

  365.     }

  366.     /** Convert regular functions to {@link MultivariateDifferentiableFunction}.
  367.      * <p>
  368.      * This method handle the case with several free parameters and only first order derivatives.
  369.      * For the case with one free parameter and several derivatives,
  370.      * see {@link #toDifferentiable(UnivariateFunction, UnivariateFunction...)}.
  371.      * There are no direct support for intermediate cases, with several free parameters
  372.      * and order 2 or more derivatives, as is would be difficult to specify all the
  373.      * cross derivatives.
  374.      * </p>
  375.      * <p>
  376.      * Note that the gradient is expected to be computed only with respect to the
  377.      * raw parameter x of the base function, i.e. it is df/dx<sub>1</sub>, df/dx<sub>2</sub>, ...
  378.      * Even if the built function is later used in a composition like f(sin(t), cos(t)), the provided
  379.      * gradient should <em>not</em> apply the composition with sine or cosine and their derivative by
  380.      * itself. The composition will be done automatically here and the result will properly
  381.      * contain f(sin(t), cos(t)), df(sin(t), cos(t))/dt despite the provided derivatives functions
  382.      * know nothing about the sine or cosine functions.
  383.      * </p>
  384.      * @param f base function f(x)
  385.      * @param gradient gradient of the base function
  386.      * @return a differentiable function with value and gradient
  387.      * @see #toDifferentiable(UnivariateFunction, UnivariateFunction...)
  388.      * @see #derivative(MultivariateDifferentiableFunction, int[])
  389.      */
  390.     public static MultivariateDifferentiableFunction toDifferentiable(final MultivariateFunction f,
  391.                                                                       final MultivariateVectorFunction gradient) {

  392.         return new MultivariateDifferentiableFunction() {

  393.             /** {@inheritDoc} */
  394.             @Override
  395.             public double value(final double[] point) {
  396.                 return f.value(point);
  397.             }

  398.             /** {@inheritDoc} */
  399.             @Override
  400.             public DerivativeStructure value(final DerivativeStructure[] point) {

  401.                 // set up the input parameters
  402.                 final double[] dPoint = new double[point.length];
  403.                 for (int i = 0; i < point.length; ++i) {
  404.                     dPoint[i] = point[i].getValue();
  405.                     if (point[i].getOrder() > 1) {
  406.                         throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_LARGE,
  407.                                                                point[i].getOrder(), 1);
  408.                     }
  409.                 }

  410.                 // evaluate regular functions
  411.                 final double    v = f.value(dPoint);
  412.                 final double[] dv = gradient.value(dPoint);
  413.                 MathUtils.checkDimension(dv.length, point.length);

  414.                 // build the combined derivative
  415.                 final int parameters = point[0].getFreeParameters();
  416.                 final double[] partials = new double[point.length];
  417.                 final double[] packed = new double[parameters + 1];
  418.                 packed[0] = v;
  419.                 final int[] orders = new int[parameters];
  420.                 for (int i = 0; i < parameters; ++i) {

  421.                     // we differentiate once with respect to parameter i
  422.                     orders[i] = 1;
  423.                     for (int j = 0; j < point.length; ++j) {
  424.                         partials[j] = point[j].getPartialDerivative(orders);
  425.                     }
  426.                     orders[i] = 0;

  427.                     // compose partial derivatives
  428.                     packed[i + 1] = MathArrays.linearCombination(dv, partials);

  429.                 }

  430.                 return point[0].getFactory().build(packed);

  431.             }

  432.         };

  433.     }

  434.     /** Convert an {@link UnivariateDifferentiableFunction} to an
  435.      * {@link UnivariateFunction} computing n<sup>th</sup> order derivative.
  436.      * <p>
  437.      * This converter is only a convenience method. Beware computing only one derivative does
  438.      * not save any computation as the original function will really be called under the hood.
  439.      * The derivative will be extracted from the full {@link DerivativeStructure} result.
  440.      * </p>
  441.      * @param f original function, with value and all its derivatives
  442.      * @param order of the derivative to extract
  443.      * @return function computing the derivative at required order
  444.      * @see #derivative(MultivariateDifferentiableFunction, int[])
  445.      * @see #toDifferentiable(UnivariateFunction, UnivariateFunction...)
  446.      */
  447.     public static UnivariateFunction derivative(final UnivariateDifferentiableFunction f, final int order) {

  448.         final DSFactory factory = new DSFactory(1, order);

  449.         return new UnivariateFunction() {

  450.             /** {@inheritDoc} */
  451.             @Override
  452.             public double value(final double x) {
  453.                 final DerivativeStructure dsX = factory.variable(0, x);
  454.                 return f.value(dsX).getPartialDerivative(order);
  455.             }

  456.         };
  457.     }

  458.     /** Convert an {@link MultivariateDifferentiableFunction} to an
  459.      * {@link MultivariateFunction} computing n<sup>th</sup> order derivative.
  460.      * <p>
  461.      * This converter is only a convenience method. Beware computing only one derivative does
  462.      * not save any computation as the original function will really be called under the hood.
  463.      * The derivative will be extracted from the full {@link DerivativeStructure} result.
  464.      * </p>
  465.      * @param f original function, with value and all its derivatives
  466.      * @param orders of the derivative to extract, for each free parameters
  467.      * @return function computing the derivative at required order
  468.      * @see #derivative(UnivariateDifferentiableFunction, int)
  469.      * @see #toDifferentiable(MultivariateFunction, MultivariateVectorFunction)
  470.      */
  471.     public static MultivariateFunction derivative(final MultivariateDifferentiableFunction f, final int[] orders) {

  472.         // the maximum differentiation order is the sum of all orders
  473.         int sum = 0;
  474.         for (final int order : orders) {
  475.             sum += order;
  476.         }
  477.         final int sumOrders = sum;

  478.         return new MultivariateFunction() {

  479.             /** Factory used for building derivatives. */
  480.             private DSFactory factory;

  481.             /** {@inheritDoc} */
  482.             @Override
  483.             public double value(final double[] point) {

  484.                 if (factory == null || point.length != factory.getCompiler().getFreeParameters()) {
  485.                     // rebuild the factory in case of mismatch
  486.                     factory = new DSFactory(point.length, sumOrders);
  487.                 }

  488.                 // set up the input parameters
  489.                 final DerivativeStructure[] dsPoint = new DerivativeStructure[point.length];
  490.                 for (int i = 0; i < point.length; ++i) {
  491.                     dsPoint[i] = factory.variable(i, point[i]);
  492.                 }

  493.                 return f.value(dsPoint).getPartialDerivative(orders);

  494.             }

  495.         };
  496.     }

  497. }