AbstractMultipleLinearRegression.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      https://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. /*
  18.  * This is not the original file distributed by the Apache Software Foundation
  19.  * It has been modified by the Hipparchus project
  20.  */
  21. package org.hipparchus.stat.regression;

  22. import org.hipparchus.exception.LocalizedCoreFormats;
  23. import org.hipparchus.exception.MathIllegalArgumentException;
  24. import org.hipparchus.exception.NullArgumentException;
  25. import org.hipparchus.linear.Array2DRowRealMatrix;
  26. import org.hipparchus.linear.ArrayRealVector;
  27. import org.hipparchus.linear.RealMatrix;
  28. import org.hipparchus.linear.RealVector;
  29. import org.hipparchus.stat.LocalizedStatFormats;
  30. import org.hipparchus.stat.descriptive.moment.Variance;
  31. import org.hipparchus.util.FastMath;
  32. import org.hipparchus.util.MathUtils;

  33. /**
  34.  * Abstract base class for implementations of MultipleLinearRegression.
  35.  */
  36. public abstract class AbstractMultipleLinearRegression implements
  37.         MultipleLinearRegression {

  38.     /** X sample data. */
  39.     private RealMatrix xMatrix;

  40.     /** Y sample data. */
  41.     private RealVector yVector;

  42.     /** Whether or not the regression model includes an intercept.  True means no intercept. */
  43.     private boolean noIntercept;

  44.     /** Empty constructor.
  45.      * <p>
  46.      * This constructor is not strictly necessary, but it prevents spurious
  47.      * javadoc warnings with JDK 18 and later.
  48.      * </p>
  49.      * @since 3.0
  50.      */
  51.     public AbstractMultipleLinearRegression() { // NOPMD - unnecessary constructor added intentionally to make javadoc happy
  52.         // nothing to do
  53.     }

  54.     /** Get the X sample data.
  55.      * @return the X sample data.
  56.      */
  57.     protected RealMatrix getX() {
  58.         return xMatrix;
  59.     }

  60.     /** Get the Y sample data.
  61.      * @return the Y sample data.
  62.      */
  63.     protected RealVector getY() {
  64.         return yVector;
  65.     }

  66.     /** Chekc if the model has no intercept term.
  67.      * @return true if the model has no intercept term; false otherwise
  68.      */
  69.     public boolean isNoIntercept() {
  70.         return noIntercept;
  71.     }

  72.     /** Set intercept flag.
  73.      * @param noIntercept true means the model is to be estimated without an intercept term
  74.      */
  75.     public void setNoIntercept(boolean noIntercept) {
  76.         this.noIntercept = noIntercept;
  77.     }

  78.     /**
  79.      * <p>Loads model x and y sample data from a flat input array, overriding any previous sample.
  80.      * </p>
  81.      * <p>Assumes that rows are concatenated with y values first in each row.  For example, an input
  82.      * <code>data</code> array containing the sequence of values (1, 2, 3, 4, 5, 6, 7, 8, 9) with
  83.      * <code>nobs = 3</code> and <code>nvars = 2</code> creates a regression dataset with two
  84.      * independent variables, as below:
  85.      * </p>
  86.      * <pre>
  87.      *   y   x[0]  x[1]
  88.      *   --------------
  89.      *   1     2     3
  90.      *   4     5     6
  91.      *   7     8     9
  92.      * </pre>
  93.      * <p>Note that there is no need to add an initial unitary column (column of 1's) when
  94.      * specifying a model including an intercept term.  If {@link #isNoIntercept()} is <code>true</code>,
  95.      * the X matrix will be created without an initial column of "1"s; otherwise this column will
  96.      * be added.
  97.      * </p>
  98.      * <p>Throws IllegalArgumentException if any of the following preconditions fail:</p>
  99.      * <ul><li><code>data</code> cannot be null</li>
  100.      * <li><code>data.length = nobs * (nvars + 1)</code></li>
  101.      * <li><code>nobs &gt; nvars</code></li></ul>
  102.      *
  103.      * @param data input data array
  104.      * @param nobs number of observations (rows)
  105.      * @param nvars number of independent variables (columns, not counting y)
  106.      * @throws NullArgumentException if the data array is null
  107.      * @throws MathIllegalArgumentException if the length of the data array is not equal
  108.      * to <code>nobs * (nvars + 1)</code>
  109.      * @throws MathIllegalArgumentException if <code>nobs</code> is less than
  110.      * <code>nvars + 1</code>
  111.      */
  112.     public void newSampleData(double[] data, int nobs, int nvars) {
  113.         MathUtils.checkNotNull(data, LocalizedCoreFormats.INPUT_ARRAY);
  114.         MathUtils.checkDimension(data.length, nobs * (nvars + 1));
  115.         if (nobs <= nvars) {
  116.             throw new MathIllegalArgumentException(LocalizedCoreFormats.INSUFFICIENT_OBSERVED_POINTS_IN_SAMPLE,
  117.                                                    nobs, nvars + 1);
  118.         }
  119.         double[] y = new double[nobs];
  120.         final int cols = noIntercept ? nvars: nvars + 1;
  121.         double[][] x = new double[nobs][cols];
  122.         int pointer = 0;
  123.         for (int i = 0; i < nobs; i++) {
  124.             y[i] = data[pointer++];
  125.             if (!noIntercept) {
  126.                 x[i][0] = 1.0d;
  127.             }
  128.             for (int j = noIntercept ? 0 : 1; j < cols; j++) {
  129.                 x[i][j] = data[pointer++];
  130.             }
  131.         }
  132.         this.xMatrix = new Array2DRowRealMatrix(x);
  133.         this.yVector = new ArrayRealVector(y);
  134.     }

  135.     /**
  136.      * Loads new y sample data, overriding any previous data.
  137.      *
  138.      * @param y the array representing the y sample
  139.      * @throws NullArgumentException if y is null
  140.      * @throws MathIllegalArgumentException if y is empty
  141.      */
  142.     protected void newYSampleData(double[] y) {
  143.         if (y == null) {
  144.             throw new NullArgumentException();
  145.         }
  146.         if (y.length == 0) {
  147.             throw new MathIllegalArgumentException(LocalizedCoreFormats.NO_DATA);
  148.         }
  149.         this.yVector = new ArrayRealVector(y);
  150.     }

  151.     /**
  152.      * <p>Loads new x sample data, overriding any previous data.
  153.      * </p>
  154.      * <p>
  155.      * The input <code>x</code> array should have one row for each sample
  156.      * observation, with columns corresponding to independent variables.
  157.      * For example, if
  158.      * </p>
  159.      * <pre>
  160.      * <code> x = new double[][] {{1, 2}, {3, 4}, {5, 6}} </code></pre>
  161.      * <p>
  162.      * then <code>setXSampleData(x) </code> results in a model with two independent
  163.      * variables and 3 observations:
  164.      * </p>
  165.      * <pre>
  166.      *   x[0]  x[1]
  167.      *   ----------
  168.      *     1    2
  169.      *     3    4
  170.      *     5    6
  171.      * </pre>
  172.      * <p>Note that there is no need to add an initial unitary column (column of 1's) when
  173.      * specifying a model including an intercept term.
  174.      * </p>
  175.      * @param x the rectangular array representing the x sample
  176.      * @throws NullArgumentException if x is null
  177.      * @throws MathIllegalArgumentException if x is empty
  178.      * @throws MathIllegalArgumentException if x is not rectangular
  179.      */
  180.     protected void newXSampleData(double[][] x) {
  181.         if (x == null) {
  182.             throw new NullArgumentException();
  183.         }
  184.         if (x.length == 0) {
  185.             throw new MathIllegalArgumentException(LocalizedCoreFormats.NO_DATA);
  186.         }
  187.         if (noIntercept) {
  188.             this.xMatrix = new Array2DRowRealMatrix(x, true);
  189.         } else { // Augment design matrix with initial unitary column
  190.             final int nVars = x[0].length;
  191.             final double[][] xAug = new double[x.length][nVars + 1];
  192.             for (int i = 0; i < x.length; i++) {
  193.                 MathUtils.checkDimension(x[i].length, nVars);
  194.                 xAug[i][0] = 1.0d;
  195.                 System.arraycopy(x[i], 0, xAug[i], 1, nVars);
  196.             }
  197.             this.xMatrix = new Array2DRowRealMatrix(xAug, false);
  198.         }
  199.     }

  200.     /**
  201.      * Validates sample data.
  202.      * <p>Checks that</p>
  203.      * <ul><li>Neither x nor y is null or empty;</li>
  204.      * <li>The length (i.e. number of rows) of x equals the length of y</li>
  205.      * <li>x has at least one more row than it has columns (i.e. there is
  206.      * sufficient data to estimate regression coefficients for each of the
  207.      * columns in x plus an intercept.</li>
  208.      * </ul>
  209.      *
  210.      * @param x the [n,k] array representing the x data
  211.      * @param y the [n,1] array representing the y data
  212.      * @throws NullArgumentException if {@code x} or {@code y} is null
  213.      * @throws MathIllegalArgumentException if {@code x} and {@code y} do not
  214.      * have the same length
  215.      * @throws MathIllegalArgumentException if {@code x} or {@code y} are zero-length
  216.      * @throws MathIllegalArgumentException if the number of rows of {@code x}
  217.      * is not larger than the number of columns + 1 if the model has an intercept;
  218.      * or the number of columns if there is no intercept term
  219.      */
  220.     protected void validateSampleData(double[][] x, double[] y) throws MathIllegalArgumentException {
  221.         if ((x == null) || (y == null)) {
  222.             throw new NullArgumentException();
  223.         }
  224.         MathUtils.checkDimension(x.length, y.length);
  225.         if (x.length == 0) {  // Must be no y data either
  226.             throw new MathIllegalArgumentException(LocalizedCoreFormats.NO_DATA);
  227.         }
  228.         if (x[0].length + (noIntercept ? 0 : 1) > x.length) {
  229.             throw new MathIllegalArgumentException(
  230.                     LocalizedStatFormats.NOT_ENOUGH_DATA_FOR_NUMBER_OF_PREDICTORS,
  231.                     x.length, x[0].length);
  232.         }
  233.     }

  234.     /**
  235.      * Validates that the x data and covariance matrix have the same
  236.      * number of rows and that the covariance matrix is square.
  237.      *
  238.      * @param x the [n,k] array representing the x sample
  239.      * @param covariance the [n,n] array representing the covariance matrix
  240.      * @throws MathIllegalArgumentException if the number of rows in x is not equal
  241.      * to the number of rows in covariance
  242.      * @throws MathIllegalArgumentException if the covariance matrix is not square
  243.      */
  244.     protected void validateCovarianceData(double[][] x, double[][] covariance) {
  245.         MathUtils.checkDimension(x.length, covariance.length);
  246.         if (covariance.length > 0 && covariance.length != covariance[0].length) {
  247.             throw new MathIllegalArgumentException(LocalizedCoreFormats.NON_SQUARE_MATRIX,
  248.                                                    covariance.length, covariance[0].length);
  249.         }
  250.     }

  251.     /**
  252.      * {@inheritDoc}
  253.      */
  254.     @Override
  255.     public double[] estimateRegressionParameters() {
  256.         RealVector b = calculateBeta();
  257.         return b.toArray();
  258.     }

  259.     /**
  260.      * {@inheritDoc}
  261.      */
  262.     @Override
  263.     public double[] estimateResiduals() {
  264.         RealVector b = calculateBeta();
  265.         RealVector e = yVector.subtract(xMatrix.operate(b));
  266.         return e.toArray();
  267.     }

  268.     /**
  269.      * {@inheritDoc}
  270.      */
  271.     @Override
  272.     public double[][] estimateRegressionParametersVariance() {
  273.         return calculateBetaVariance().getData();
  274.     }

  275.     /**
  276.      * {@inheritDoc}
  277.      */
  278.     @Override
  279.     public double[] estimateRegressionParametersStandardErrors() {
  280.         double[][] betaVariance = estimateRegressionParametersVariance();
  281.         double sigma = calculateErrorVariance();
  282.         int length = betaVariance[0].length;
  283.         double[] result = new double[length];
  284.         for (int i = 0; i < length; i++) {
  285.             result[i] = FastMath.sqrt(sigma * betaVariance[i][i]);
  286.         }
  287.         return result;
  288.     }

  289.     /**
  290.      * {@inheritDoc}
  291.      */
  292.     @Override
  293.     public double estimateRegressandVariance() {
  294.         return calculateYVariance();
  295.     }

  296.     /**
  297.      * Estimates the variance of the error.
  298.      *
  299.      * @return estimate of the error variance
  300.      */
  301.     public double estimateErrorVariance() {
  302.         return calculateErrorVariance();

  303.     }

  304.     /**
  305.      * Estimates the standard error of the regression.
  306.      *
  307.      * @return regression standard error
  308.      */
  309.     public double estimateRegressionStandardError() {
  310.         return FastMath.sqrt(estimateErrorVariance());
  311.     }

  312.     /**
  313.      * Calculates the beta of multiple linear regression in matrix notation.
  314.      *
  315.      * @return beta
  316.      */
  317.     protected abstract RealVector calculateBeta();

  318.     /**
  319.      * Calculates the beta variance of multiple linear regression in matrix
  320.      * notation.
  321.      *
  322.      * @return beta variance
  323.      */
  324.     protected abstract RealMatrix calculateBetaVariance();


  325.     /**
  326.      * Calculates the variance of the y values.
  327.      *
  328.      * @return Y variance
  329.      */
  330.     protected double calculateYVariance() {
  331.         return new Variance().evaluate(yVector.toArray());
  332.     }

  333.     /**
  334.      * <p>Calculates the variance of the error term.</p>
  335.      * Uses the formula <pre>
  336.      * var(u) = u &middot; u / (n - k)
  337.      * </pre>
  338.      * where n and k are the row and column dimensions of the design
  339.      * matrix X.
  340.      *
  341.      * @return error variance estimate
  342.      */
  343.     protected double calculateErrorVariance() {
  344.         RealVector residuals = calculateResiduals();
  345.         return residuals.dotProduct(residuals) /
  346.                (xMatrix.getRowDimension() - xMatrix.getColumnDimension());
  347.     }

  348.     /**
  349.      * Calculates the residuals of multiple linear regression in matrix
  350.      * notation.
  351.      *
  352.      * <pre>
  353.      * u = y - X * b
  354.      * </pre>
  355.      *
  356.      * @return The residuals [n,1] matrix
  357.      */
  358.     protected RealVector calculateResiduals() {
  359.         RealVector b = calculateBeta();
  360.         return yVector.subtract(xMatrix.operate(b));
  361.     }

  362. }