PCA.java

  1. /*
  2.  * Licensed to the Hipparchus project under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The Hipparchus project licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      https://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.hipparchus.stat.projection;

  18. import org.hipparchus.exception.MathIllegalStateException;
  19. import org.hipparchus.linear.EigenDecompositionSymmetric;
  20. import org.hipparchus.linear.MatrixUtils;
  21. import org.hipparchus.linear.RealMatrix;
  22. import org.hipparchus.stat.LocalizedStatFormats;
  23. import org.hipparchus.stat.StatUtils;
  24. import org.hipparchus.stat.correlation.Covariance;
  25. import org.hipparchus.stat.descriptive.moment.StandardDeviation;

  26. /**
  27.  * Principal component analysis (PCA) is a statistical technique for reducing the dimensionality of a dataset.
  28.  * <a href="https://en.wikipedia.org/wiki/Principal_component_analysis">PCA</a> can be thought of as a
  29.  * projection or scaling of the data to reduce the number of dimensions but done in a way
  30.  * that preserves as much information as possible.
  31.  * @since 3.0
  32.  */
  33. public class PCA {
  34.     /**
  35.      * The number of components (reduced dimensions) for this projection.
  36.      */
  37.     private final int numC;

  38.     /**
  39.      * Whether to scale (standardize) the input data as well as center (normalize).
  40.      */
  41.     private final boolean scale;

  42.     /**
  43.      * Whether to correct for bias when standardizing. Ignored when only centering.
  44.      */
  45.     private final boolean biasCorrection;

  46.     /**
  47.      * The by column (feature) averages (means) from the fitted data.
  48.      */
  49.     private double[] center;

  50.     /**
  51.      * The by column (feature) standard deviates from the fitted data.
  52.      */
  53.     private double[] std;

  54.     /**
  55.      * The eigenValues (variance) of our projection model.
  56.      */
  57.     private double[] eigenValues;

  58.     /**
  59.      * The eigenVectors (components) of our projection model.
  60.      */
  61.     private RealMatrix principalComponents;

  62.     /**
  63.      * Utility class when scaling.
  64.      */
  65.     private final StandardDeviation sd;

  66.     /**
  67.      * Create a PCA with the ability to adjust scaling parameters.
  68.      *
  69.      * @param numC the number of components
  70.      * @param scale whether to also scale (correlation) rather than just center (covariance)
  71.      * @param biasCorrection whether to adjust for bias when scaling
  72.      */
  73.     public PCA(int numC, boolean scale, boolean biasCorrection) {
  74.         this.numC = numC;
  75.         this.scale = scale;
  76.         this.biasCorrection = biasCorrection;
  77.         sd = scale ? new StandardDeviation(biasCorrection) : null;
  78.     }

  79.     /**
  80.      * A default PCA will center but not scale.
  81.      *
  82.      * @param numC the number of components
  83.      */
  84.     public PCA(int numC) {
  85.         this(numC, false, true);
  86.     }

  87.     /** GEt number of components.
  88.      * @return the number of components
  89.      */
  90.     public int getNumComponents() {
  91.         return numC;
  92.     }

  93.     /** Check whether scaling (correlation) or no scaling (covariance) is used.
  94.      * @return whether scaling (correlation) or no scaling (covariance) is used
  95.      */
  96.     public boolean isScale() {
  97.         return scale;
  98.     }

  99.     /** Check whether scaling (correlation), if in use, adjusts for bias.
  100.      * @return whether scaling (correlation), if in use, adjusts for bias
  101.      */
  102.     public boolean isBiasCorrection() {
  103.         return biasCorrection;
  104.     }

  105.     /** Get principal component variances.
  106.      * @return the principal component variances, ordered from largest to smallest, which are the eigenvalues of the covariance or correlation matrix of the fitted data
  107.      */
  108.     public double[] getVariance() {
  109.         validateState("getVariance");
  110.         return eigenValues.clone();
  111.     }

  112.     /** Get by column center (or mean) of the fitted data.
  113.      * @return the by column center (or mean) of the fitted data
  114.      */
  115.     public double[] getCenter() {
  116.         validateState("getCenter");
  117.         return center.clone();
  118.     }

  119.     /**
  120.      * Returns the principal components of our projection model.
  121.      * These are the eigenvectors of our covariance/correlation matrix.
  122.      *
  123.      * @return the principal components
  124.      */
  125.     public double[][] getComponents() {
  126.         validateState("getComponents");
  127.         return principalComponents.getData();
  128.     }

  129.     /**
  130.      * Fit our model to the data and then transform it to the reduced dimensions.
  131.      *
  132.      * @param data the input data
  133.      * @return the fitted data
  134.      */
  135.     public double[][] fitAndTransform(double[][] data) {
  136.         center = null;
  137.         RealMatrix normalizedM = getNormalizedMatrix(data);
  138.         calculatePrincipalComponents(normalizedM);
  139.         return normalizedM.multiply(principalComponents).getData();
  140.     }

  141.     /**
  142.      * Transform the supplied data using our projection model.
  143.      *
  144.      * @param data the input data
  145.      * @return the fitted data
  146.      */
  147.     public double[][] transform(double[][] data) {
  148.         validateState("transform");
  149.         RealMatrix normalizedM = getNormalizedMatrix(data);
  150.         return normalizedM.multiply(principalComponents).getData();
  151.     }

  152.     /**
  153.      * Fit our model to the data, ready for subsequence transforms.
  154.      *
  155.      * @param data the input data
  156.      * @return this
  157.      */
  158.     public PCA fit(double[][] data) {
  159.         center = null;
  160.         RealMatrix normalized = getNormalizedMatrix(data);
  161.         calculatePrincipalComponents(normalized);
  162.         return this;
  163.     }

  164.     /** Check if the state allows an operation to be performed.
  165.      * @param from name of the operation
  166.      * @exception MathIllegalStateException if the state does not allows operation
  167.      */
  168.     private void validateState(String from) {
  169.         if (center == null) {
  170.             throw new MathIllegalStateException(LocalizedStatFormats.ILLEGAL_STATE_PCA, from);
  171.         }

  172.     }

  173.     /** Compute eigenvalues and principal components.
  174.      * <p>
  175.      * The results are stored in the instance itself
  176.      * <p>
  177.      * @param normalizedM normalized matrix
  178.      */
  179.     private void calculatePrincipalComponents(RealMatrix normalizedM) {
  180.         RealMatrix covarianceM = new Covariance(normalizedM).getCovarianceMatrix();
  181.         EigenDecompositionSymmetric decomposition = new EigenDecompositionSymmetric(covarianceM);
  182.         eigenValues = decomposition.getEigenvalues();
  183.         principalComponents = MatrixUtils.createRealMatrix(eigenValues.length, numC);
  184.         for (int c = 0; c < numC; c++) {
  185.             for (int f = 0; f < eigenValues.length; f++) {
  186.                 principalComponents.setEntry(f, c, decomposition.getEigenvector(c).getEntry(f));
  187.             }
  188.         }
  189.     }

  190.     /**
  191.      * This will either normalize (center) or
  192.      * standardize (center plus scale) the input data.
  193.      *
  194.      * @param input the input data
  195.      * @return the normalized (or standardized) matrix
  196.      */
  197.     private RealMatrix getNormalizedMatrix(double[][] input) {
  198.         int numS = input.length;
  199.         int numF = input[0].length;
  200.         boolean calculating = center == null;
  201.         if (calculating) {
  202.             center = new double[numF];
  203.             if (scale) {
  204.                 std = new double[numF];
  205.             }
  206.         }

  207.         double[][] normalized = new double[numS][numF];
  208.         for (int f = 0; f < numF; f++) {
  209.             if (calculating) {
  210.                 calculateNormalizeParameters(input, numS, f);
  211.             }
  212.             for (int s = 0; s < numS; s++) {
  213.                 normalized[s][f] = input[s][f] - center[f];
  214.             }
  215.             if (scale) {
  216.                 for (int s = 0; s < numS; s++) {
  217.                     normalized[s][f] /= std[f];
  218.                 }
  219.             }
  220.         }

  221.         return MatrixUtils.createRealMatrix(normalized);
  222.     }

  223.     /** compute normalized parameters.
  224.      * @param input the input data
  225.      * @param numS number of data rows
  226.      * @param f index of the component
  227.      */
  228.     private void calculateNormalizeParameters(double[][] input, int numS, int f) {
  229.         double[] column = new double[numS];
  230.         for (int s = 0; s < numS; s++) {
  231.             column[s] = input[s][f];
  232.         }
  233.         center[f] = StatUtils.mean(column);
  234.         if (scale) {
  235.             std[f] = sd.evaluate(column, center[f]);
  236.         }
  237.     }
  238. }