1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one or more 3 * contributor license agreements. See the NOTICE file distributed with 4 * this work for additional information regarding copyright ownership. 5 * The ASF licenses this file to You under the Apache License, Version 2.0 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * https://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 18 /* 19 * This is not the original file distributed by the Apache Software Foundation 20 * It has been modified by the Hipparchus project 21 */ 22 package org.hipparchus.analysis.differentiation; 23 24 import org.hipparchus.analysis.MultivariateVectorFunction; 25 26 /** Class representing the gradient of a multivariate function. 27 * <p> 28 * The vectorial components of the function represent the derivatives 29 * with respect to each function parameters. 30 * </p> 31 */ 32 public class GradientFunction implements MultivariateVectorFunction { 33 34 /** Underlying real-valued function. */ 35 private final MultivariateDifferentiableFunction f; 36 37 /** Simple constructor. 38 * @param f underlying real-valued function 39 */ 40 public GradientFunction(final MultivariateDifferentiableFunction f) { 41 this.f = f; 42 } 43 44 /** {@inheritDoc} */ 45 @Override 46 public double[] value(double[] point) { 47 48 // set up parameters 49 final DSFactory factory = new DSFactory(point.length, 1); 50 final DerivativeStructure[] dsX = new DerivativeStructure[point.length]; 51 for (int i = 0; i < point.length; ++i) { 52 dsX[i] = factory.variable(i, point[i]); 53 } 54 55 // compute the derivatives 56 final DerivativeStructure dsY = f.value(dsX); 57 58 // extract the gradient 59 final double[] y = new double[point.length]; 60 final int[] orders = new int[point.length]; 61 for (int i = 0; i < point.length; ++i) { 62 orders[i] = 1; 63 y[i] = dsY.getPartialDerivative(orders); 64 orders[i] = 0; 65 } 66 67 return y; 68 69 } 70 71 }