GLSMultipleLinearRegression.java
- /*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- /*
- * This is not the original file distributed by the Apache Software Foundation
- * It has been modified by the Hipparchus project
- */
- package org.hipparchus.stat.regression;
- import org.hipparchus.linear.Array2DRowRealMatrix;
- import org.hipparchus.linear.LUDecomposition;
- import org.hipparchus.linear.RealMatrix;
- import org.hipparchus.linear.RealVector;
- /**
- * The GLS implementation of multiple linear regression.
- *
- * GLS assumes a general covariance matrix Omega of the error
- * <pre>
- * u ~ N(0, Omega)
- * </pre>
- *
- * Estimated by GLS,
- * <pre>
- * b=(X' Omega^-1 X)^-1X'Omega^-1 y
- * </pre>
- * whose variance is
- * <pre>
- * Var(b)=(X' Omega^-1 X)^-1
- * </pre>
- */
- public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegression {
- /** Covariance matrix. */
- private RealMatrix Omega;
- /** Inverse of covariance matrix. */
- private RealMatrix OmegaInverse;
- /** Empty constructor.
- * <p>
- * This constructor is not strictly necessary, but it prevents spurious
- * javadoc warnings with JDK 18 and later.
- * </p>
- * @since 3.0
- */
- public GLSMultipleLinearRegression() { // NOPMD - unnecessary constructor added intentionally to make javadoc happy
- // nothing to do
- }
- /** Replace sample data, overriding any previous sample.
- * @param y y values of the sample
- * @param x x values of the sample
- * @param covariance array representing the covariance matrix
- */
- public void newSampleData(double[] y, double[][] x, double[][] covariance) {
- validateSampleData(x, y);
- newYSampleData(y);
- newXSampleData(x);
- validateCovarianceData(x, covariance);
- newCovarianceData(covariance);
- }
- /**
- * Add the covariance data.
- *
- * @param omega the [n,n] array representing the covariance
- */
- protected void newCovarianceData(double[][] omega){
- this.Omega = new Array2DRowRealMatrix(omega);
- this.OmegaInverse = null;
- }
- /**
- * Get the inverse of the covariance.
- * <p>The inverse of the covariance matrix is lazily evaluated and cached.</p>
- * @return inverse of the covariance
- */
- protected RealMatrix getOmegaInverse() {
- if (OmegaInverse == null) {
- OmegaInverse = new LUDecomposition(Omega).getSolver().getInverse();
- }
- return OmegaInverse;
- }
- /**
- * Calculates beta by GLS.
- * <pre>
- * b=(X' Omega^-1 X)^-1X'Omega^-1 y
- * </pre>
- * @return beta
- */
- @Override
- protected RealVector calculateBeta() {
- RealMatrix OI = getOmegaInverse();
- RealMatrix XT = getX().transpose();
- RealMatrix XTOIX = XT.multiply(OI).multiply(getX());
- RealMatrix inverse = new LUDecomposition(XTOIX).getSolver().getInverse();
- return inverse.multiply(XT).multiply(OI).operate(getY());
- }
- /**
- * Calculates the variance on the beta.
- * <pre>
- * Var(b)=(X' Omega^-1 X)^-1
- * </pre>
- * @return The beta variance matrix
- */
- @Override
- protected RealMatrix calculateBetaVariance() {
- RealMatrix OI = getOmegaInverse();
- RealMatrix XTOIX = getX().transposeMultiply(OI).multiply(getX());
- return new LUDecomposition(XTOIX).getSolver().getInverse();
- }
- /**
- * Calculates the estimated variance of the error term using the formula
- * <pre>
- * Var(u) = Tr(u' Omega^-1 u)/(n-k)
- * </pre>
- * where n and k are the row and column dimensions of the design
- * matrix X.
- *
- * @return error variance
- */
- @Override
- protected double calculateErrorVariance() {
- RealVector residuals = calculateResiduals();
- double t = residuals.dotProduct(getOmegaInverse().operate(residuals));
- return t / (getX().getRowDimension() - getX().getColumnDimension());
- }
- }