1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one or more 3 * contributor license agreements. See the NOTICE file distributed with 4 * this work for additional information regarding copyright ownership. 5 * The ASF licenses this file to You under the Apache License, Version 2.0 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * https://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 18 /* 19 * This is not the original file distributed by the Apache Software Foundation 20 * It has been modified by the Hipparchus project 21 */ 22 23 package org.hipparchus.analysis.function; 24 25 import java.util.Arrays; 26 27 import org.hipparchus.analysis.ParametricUnivariateFunction; 28 import org.hipparchus.analysis.differentiation.Derivative; 29 import org.hipparchus.analysis.differentiation.UnivariateDifferentiableFunction; 30 import org.hipparchus.exception.LocalizedCoreFormats; 31 import org.hipparchus.exception.MathIllegalArgumentException; 32 import org.hipparchus.exception.NullArgumentException; 33 import org.hipparchus.util.FastMath; 34 import org.hipparchus.util.MathUtils; 35 import org.hipparchus.util.Precision; 36 37 /** 38 * <a href="http://en.wikipedia.org/wiki/Gaussian_function"> 39 * Gaussian</a> function. 40 * 41 */ 42 public class Gaussian implements UnivariateDifferentiableFunction { 43 /** Mean. */ 44 private final double mean; 45 /** Inverse of the standard deviation. */ 46 private final double is; 47 /** Inverse of twice the square of the standard deviation. */ 48 private final double i2s2; 49 /** Normalization factor. */ 50 private final double norm; 51 52 /** 53 * Gaussian with given normalization factor, mean and standard deviation. 54 * 55 * @param norm Normalization factor. 56 * @param mean Mean. 57 * @param sigma Standard deviation. 58 * @throws MathIllegalArgumentException if {@code sigma <= 0}. 59 */ 60 public Gaussian(double norm, 61 double mean, 62 double sigma) 63 throws MathIllegalArgumentException { 64 if (sigma <= 0) { 65 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL_BOUND_EXCLUDED, 66 sigma, 0); 67 } 68 69 this.norm = norm; 70 this.mean = mean; 71 this.is = 1 / sigma; 72 this.i2s2 = 0.5 * is * is; 73 } 74 75 /** 76 * Normalized gaussian with given mean and standard deviation. 77 * 78 * @param mean Mean. 79 * @param sigma Standard deviation. 80 * @throws MathIllegalArgumentException if {@code sigma <= 0}. 81 */ 82 public Gaussian(double mean, 83 double sigma) 84 throws MathIllegalArgumentException { 85 this(1 / (sigma * FastMath.sqrt(2 * Math.PI)), mean, sigma); 86 } 87 88 /** 89 * Normalized gaussian with zero mean and unit standard deviation. 90 */ 91 public Gaussian() { 92 this(0, 1); 93 } 94 95 /** {@inheritDoc} */ 96 @Override 97 public double value(double x) { 98 return value(x - mean, norm, i2s2); 99 } 100 101 /** 102 * Parametric function where the input array contains the parameters of 103 * the Gaussian, ordered as follows: 104 * <ul> 105 * <li>Norm</li> 106 * <li>Mean</li> 107 * <li>Standard deviation</li> 108 * </ul> 109 */ 110 public static class Parametric implements ParametricUnivariateFunction { 111 112 /** Empty constructor. 113 * <p> 114 * This constructor is not strictly necessary, but it prevents spurious 115 * javadoc warnings with JDK 18 and later. 116 * </p> 117 * @since 3.0 118 */ 119 public Parametric() { // NOPMD - unnecessary constructor added intentionally to make javadoc happy 120 // nothing to do 121 } 122 123 /** 124 * Computes the value of the Gaussian at {@code x}. 125 * 126 * @param x Value for which the function must be computed. 127 * @param param Values of norm, mean and standard deviation. 128 * @return the value of the function. 129 * @throws NullArgumentException if {@code param} is {@code null}. 130 * @throws MathIllegalArgumentException if the size of {@code param} is 131 * not 3. 132 * @throws MathIllegalArgumentException if {@code param[2]} is negative. 133 */ 134 @Override 135 public double value(double x, double ... param) 136 throws MathIllegalArgumentException, NullArgumentException { 137 validateParameters(param); 138 139 final double diff = x - param[1]; 140 final double i2s2 = 1 / (2 * param[2] * param[2]); 141 return Gaussian.value(diff, param[0], i2s2); 142 } 143 144 /** 145 * Computes the value of the gradient at {@code x}. 146 * The components of the gradient vector are the partial 147 * derivatives of the function with respect to each of the 148 * <em>parameters</em> (norm, mean and standard deviation). 149 * 150 * @param x Value at which the gradient must be computed. 151 * @param param Values of norm, mean and standard deviation. 152 * @return the gradient vector at {@code x}. 153 * @throws NullArgumentException if {@code param} is {@code null}. 154 * @throws MathIllegalArgumentException if the size of {@code param} is 155 * not 3. 156 * @throws MathIllegalArgumentException if {@code param[2]} is negative. 157 */ 158 @Override 159 public double[] gradient(double x, double ... param) 160 throws MathIllegalArgumentException, NullArgumentException { 161 validateParameters(param); 162 163 final double norm = param[0]; 164 final double diff = x - param[1]; 165 final double sigma = param[2]; 166 final double i2s2 = 1 / (2 * sigma * sigma); 167 168 final double n = Gaussian.value(diff, 1, i2s2); 169 final double m = norm * n * 2 * i2s2 * diff; 170 final double s = m * diff / sigma; 171 172 return new double[] { n, m, s }; 173 } 174 175 /** 176 * Validates parameters to ensure they are appropriate for the evaluation of 177 * the {@link #value(double,double[])} and {@link #gradient(double,double[])} 178 * methods. 179 * 180 * @param param Values of norm, mean and standard deviation. 181 * @throws NullArgumentException if {@code param} is {@code null}. 182 * @throws MathIllegalArgumentException if the size of {@code param} is 183 * not 3. 184 * @throws MathIllegalArgumentException if {@code param[2]} is negative. 185 */ 186 private void validateParameters(double[] param) 187 throws MathIllegalArgumentException, NullArgumentException { 188 if (param == null) { 189 throw new NullArgumentException(); 190 } 191 MathUtils.checkDimension(param.length, 3); 192 if (param[2] <= 0) { 193 throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL_BOUND_EXCLUDED, 194 param[2], 0); 195 } 196 } 197 } 198 199 /** 200 * @param xMinusMean {@code x - mean}. 201 * @param norm Normalization factor. 202 * @param i2s2 Inverse of twice the square of the standard deviation. 203 * @return the value of the Gaussian at {@code x}. 204 */ 205 private static double value(double xMinusMean, 206 double norm, 207 double i2s2) { 208 return norm * FastMath.exp(-xMinusMean * xMinusMean * i2s2); 209 } 210 211 /** {@inheritDoc} 212 */ 213 @Override 214 public <T extends Derivative<T>> T value(T t) 215 throws MathIllegalArgumentException { 216 217 final double u = is * (t.getValue() - mean); 218 double[] f = new double[t.getOrder() + 1]; 219 220 // the nth order derivative of the Gaussian has the form: 221 // dn(g(x)/dxn = (norm / s^n) P_n(u) exp(-u^2/2) with u=(x-m)/s 222 // where P_n(u) is a degree n polynomial with same parity as n 223 // P_0(u) = 1, P_1(u) = -u, P_2(u) = u^2 - 1, P_3(u) = -u^3 + 3 u... 224 // the general recurrence relation for P_n is: 225 // P_n(u) = P_(n-1)'(u) - u P_(n-1)(u) 226 // as per polynomial parity, we can store coefficients of both P_(n-1) and P_n in the same array 227 final double[] p = new double[f.length]; 228 p[0] = 1; 229 final double u2 = u * u; 230 double coeff = norm * FastMath.exp(-0.5 * u2); 231 if (coeff <= Precision.SAFE_MIN) { 232 Arrays.fill(f, 0.0); 233 } else { 234 f[0] = coeff; 235 for (int n = 1; n < f.length; ++n) { 236 237 // update and evaluate polynomial P_n(x) 238 double v = 0; 239 p[n] = -p[n - 1]; 240 for (int k = n; k >= 0; k -= 2) { 241 v = v * u2 + p[k]; 242 if (k > 2) { 243 p[k - 2] = (k - 1) * p[k - 1] - p[k - 3]; 244 } else if (k == 2) { 245 p[0] = p[1]; 246 } 247 } 248 if ((n & 0x1) == 1) { 249 v *= u; 250 } 251 252 coeff *= is; 253 f[n] = coeff * v; 254 255 } 256 } 257 258 return t.compose(f); 259 260 } 261 262 }