MultivariateSummaryStatistics.java

  1. /*
  2.  * Licensed to the Apache Software Foundation (ASF) under one or more
  3.  * contributor license agreements.  See the NOTICE file distributed with
  4.  * this work for additional information regarding copyright ownership.
  5.  * The ASF licenses this file to You under the Apache License, Version 2.0
  6.  * (the "License"); you may not use this file except in compliance with
  7.  * the License.  You may obtain a copy of the License at
  8.  *
  9.  *      https://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */

  17. /*
  18.  * This is not the original file distributed by the Apache Software Foundation
  19.  * It has been modified by the Hipparchus project
  20.  */
  21. package org.hipparchus.stat.descriptive;

  22. import java.io.Serializable;
  23. import java.util.Arrays;

  24. import org.hipparchus.exception.LocalizedCoreFormats;
  25. import org.hipparchus.exception.MathIllegalArgumentException;
  26. import org.hipparchus.linear.RealMatrix;
  27. import org.hipparchus.stat.descriptive.moment.GeometricMean;
  28. import org.hipparchus.stat.descriptive.moment.Mean;
  29. import org.hipparchus.stat.descriptive.rank.Max;
  30. import org.hipparchus.stat.descriptive.rank.Min;
  31. import org.hipparchus.stat.descriptive.summary.Sum;
  32. import org.hipparchus.stat.descriptive.summary.SumOfLogs;
  33. import org.hipparchus.stat.descriptive.summary.SumOfSquares;
  34. import org.hipparchus.stat.descriptive.vector.VectorialCovariance;
  35. import org.hipparchus.stat.descriptive.vector.VectorialStorelessStatistic;
  36. import org.hipparchus.util.FastMath;
  37. import org.hipparchus.util.MathArrays;
  38. import org.hipparchus.util.MathUtils;

  39. /**
  40.  * Computes summary statistics for a stream of n-tuples added using the
  41.  * {@link #addValue(double[]) addValue} method. The data values are not stored
  42.  * in memory, so this class can be used to compute statistics for very large
  43.  * n-tuple streams.
  44.  * <p>
  45.  * To compute statistics for a stream of n-tuples, construct a
  46.  * {@link MultivariateSummaryStatistics} instance with dimension n and then use
  47.  * {@link #addValue(double[])} to add n-tuples. The <code>getXxx</code>
  48.  * methods where Xxx is a statistic return an array of <code>double</code>
  49.  * values, where for <code>i = 0,...,n-1</code> the i<sup>th</sup> array element
  50.  * is the value of the given statistic for data range consisting of the i<sup>th</sup>
  51.  * element of each of the input n-tuples.  For example, if <code>addValue</code> is
  52.  * called with actual parameters {0, 1, 2}, then {3, 4, 5} and finally {6, 7, 8},
  53.  * <code>getSum</code> will return a three-element array with values {0+3+6, 1+4+7, 2+5+8}
  54.  * <p>
  55.  * Note: This class is not thread-safe.
  56.  */
  57. public class MultivariateSummaryStatistics
  58.     implements StatisticalMultivariateSummary, Serializable {

  59.     /** Serialization UID */
  60.     private static final long serialVersionUID = 20160424L;

  61.     /** Dimension of the data. */
  62.     private final int k;

  63.     /** Sum statistic implementation */
  64.     private final StorelessMultivariateStatistic sumImpl;
  65.     /** Sum of squares statistic implementation */
  66.     private final StorelessMultivariateStatistic sumSqImpl;
  67.     /** Minimum statistic implementation */
  68.     private final StorelessMultivariateStatistic minImpl;
  69.     /** Maximum statistic implementation */
  70.     private final StorelessMultivariateStatistic maxImpl;
  71.     /** Sum of log statistic implementation */
  72.     private final StorelessMultivariateStatistic sumLogImpl;
  73.     /** Geometric mean statistic implementation */
  74.     private final StorelessMultivariateStatistic geoMeanImpl;
  75.     /** Mean statistic implementation */
  76.     private final StorelessMultivariateStatistic meanImpl;
  77.     /** Covariance statistic implementation */
  78.     private final VectorialCovariance covarianceImpl;

  79.     /** Count of values that have been added */
  80.     private long n;

  81.     /**
  82.      * Construct a MultivariateSummaryStatistics instance for the given
  83.      * dimension. The returned instance will compute the unbiased sample
  84.      * covariance.
  85.      * <p>
  86.      * The returned instance is <b>not</b> thread-safe.
  87.      *
  88.      * @param dimension dimension of the data
  89.      */
  90.     public MultivariateSummaryStatistics(int dimension) {
  91.         this(dimension, true);
  92.     }

  93.     /**
  94.      * Construct a MultivariateSummaryStatistics instance for the given
  95.      * dimension.
  96.      * <p>
  97.      * The returned instance is <b>not</b> thread-safe.
  98.      *
  99.      * @param dimension dimension of the data
  100.      * @param covarianceBiasCorrection if true, the returned instance will compute
  101.      * the unbiased sample covariance, otherwise the population covariance
  102.      */
  103.     public MultivariateSummaryStatistics(int dimension, boolean covarianceBiasCorrection) {
  104.         this.k = dimension;

  105.         sumImpl     = new VectorialStorelessStatistic(k, new Sum());
  106.         sumSqImpl   = new VectorialStorelessStatistic(k, new SumOfSquares());
  107.         minImpl     = new VectorialStorelessStatistic(k, new Min());
  108.         maxImpl     = new VectorialStorelessStatistic(k, new Max());
  109.         sumLogImpl  = new VectorialStorelessStatistic(k, new SumOfLogs());
  110.         geoMeanImpl = new VectorialStorelessStatistic(k, new GeometricMean());
  111.         meanImpl    = new VectorialStorelessStatistic(k, new Mean());

  112.         covarianceImpl = new VectorialCovariance(k, covarianceBiasCorrection);
  113.     }

  114.     /**
  115.      * Add an n-tuple to the data
  116.      *
  117.      * @param value  the n-tuple to add
  118.      * @throws MathIllegalArgumentException if the array is null or the length
  119.      * of the array does not match the one used at construction
  120.      */
  121.     public void addValue(double[] value) throws MathIllegalArgumentException {
  122.         MathUtils.checkNotNull(value, LocalizedCoreFormats.INPUT_ARRAY);
  123.         MathUtils.checkDimension(value.length, k);
  124.         sumImpl.increment(value);
  125.         sumSqImpl.increment(value);
  126.         minImpl.increment(value);
  127.         maxImpl.increment(value);
  128.         sumLogImpl.increment(value);
  129.         geoMeanImpl.increment(value);
  130.         meanImpl.increment(value);
  131.         covarianceImpl.increment(value);
  132.         n++;
  133.     }

  134.     /**
  135.      * Resets all statistics and storage.
  136.      */
  137.     public void clear() {
  138.         this.n = 0;
  139.         minImpl.clear();
  140.         maxImpl.clear();
  141.         sumImpl.clear();
  142.         sumLogImpl.clear();
  143.         sumSqImpl.clear();
  144.         geoMeanImpl.clear();
  145.         meanImpl.clear();
  146.         covarianceImpl.clear();
  147.     }

  148.     /** {@inheritDoc} **/
  149.     @Override
  150.     public int getDimension() {
  151.         return k;
  152.     }

  153.     /** {@inheritDoc} **/
  154.     @Override
  155.     public long getN() {
  156.         return n;
  157.     }

  158.     /** {@inheritDoc} **/
  159.     @Override
  160.     public double[] getSum() {
  161.         return sumImpl.getResult();
  162.     }

  163.     /** {@inheritDoc} **/
  164.     @Override
  165.     public double[] getSumSq() {
  166.         return sumSqImpl.getResult();
  167.     }

  168.     /** {@inheritDoc} **/
  169.     @Override
  170.     public double[] getSumLog() {
  171.         return sumLogImpl.getResult();
  172.     }

  173.     /** {@inheritDoc} **/
  174.     @Override
  175.     public double[] getMean() {
  176.         return meanImpl.getResult();
  177.     }

  178.     /** {@inheritDoc} **/
  179.     @Override
  180.     public RealMatrix getCovariance() {
  181.         return covarianceImpl.getResult();
  182.     }

  183.     /** {@inheritDoc} **/
  184.     @Override
  185.     public double[] getMax() {
  186.         return maxImpl.getResult();
  187.     }

  188.     /** {@inheritDoc} **/
  189.     @Override
  190.     public double[] getMin() {
  191.         return minImpl.getResult();
  192.     }

  193.     /** {@inheritDoc} **/
  194.     @Override
  195.     public double[] getGeometricMean() {
  196.         return geoMeanImpl.getResult();
  197.     }

  198.     /**
  199.      * Returns an array whose i<sup>th</sup> entry is the standard deviation of the
  200.      * i<sup>th</sup> entries of the arrays that have been added using
  201.      * {@link #addValue(double[])}
  202.      *
  203.      * @return the array of component standard deviations
  204.      */
  205.     @Override
  206.     public double[] getStandardDeviation() {
  207.         double[] stdDev = new double[k];
  208.         if (getN() < 1) {
  209.             Arrays.fill(stdDev, Double.NaN);
  210.         } else if (getN() < 2) {
  211.             Arrays.fill(stdDev, 0.0);
  212.         } else {
  213.             RealMatrix matrix = getCovariance();
  214.             for (int i = 0; i < k; ++i) {
  215.                 stdDev[i] = FastMath.sqrt(matrix.getEntry(i, i));
  216.             }
  217.         }
  218.         return stdDev;
  219.     }

  220.     /**
  221.      * Generates a text report displaying
  222.      * summary statistics from values that
  223.      * have been added.
  224.      * @return String with line feeds displaying statistics
  225.      */
  226.     @Override
  227.     public String toString() {
  228.         final String separator = ", ";
  229.         final String suffix = System.getProperty("line.separator");
  230.         StringBuilder outBuffer = new StringBuilder(200); // the size is just a wild guess
  231.         outBuffer.append("MultivariateSummaryStatistics:").append(suffix).
  232.                   append("n: ").append(getN()).append(suffix);
  233.         append(outBuffer, getMin(), "min: ", separator, suffix);
  234.         append(outBuffer, getMax(), "max: ", separator, suffix);
  235.         append(outBuffer, getMean(), "mean: ", separator, suffix);
  236.         append(outBuffer, getGeometricMean(), "geometric mean: ", separator, suffix);
  237.         append(outBuffer, getSumSq(), "sum of squares: ", separator, suffix);
  238.         append(outBuffer, getSumLog(), "sum of logarithms: ", separator, suffix);
  239.         append(outBuffer, getStandardDeviation(), "standard deviation: ", separator, suffix);
  240.         outBuffer.append("covariance: ").append(getCovariance().toString()).append(suffix);
  241.         return outBuffer.toString();
  242.     }

  243.     /**
  244.      * Append a text representation of an array to a buffer.
  245.      * @param buffer buffer to fill
  246.      * @param data data array
  247.      * @param prefix text prefix
  248.      * @param separator elements separator
  249.      * @param suffix text suffix
  250.      */
  251.     private void append(StringBuilder buffer, double[] data,
  252.                         String prefix, String separator, String suffix) {
  253.         buffer.append(prefix);
  254.         for (int i = 0; i < data.length; ++i) {
  255.             if (i > 0) {
  256.                 buffer.append(separator);
  257.             }
  258.             buffer.append(data[i]);
  259.         }
  260.         buffer.append(suffix);
  261.     }

  262.     /**
  263.      * Returns true iff <code>object</code> is a <code>MultivariateSummaryStatistics</code>
  264.      * instance and all statistics have the same values as this.
  265.      * @param object the object to test equality against.
  266.      * @return true if object equals this
  267.      */
  268.     @Override
  269.     public boolean equals(Object object) {
  270.         if (object == this) {
  271.             return true;
  272.         }
  273.         if (!(object instanceof MultivariateSummaryStatistics)) {
  274.             return false;
  275.         }
  276.         MultivariateSummaryStatistics other = (MultivariateSummaryStatistics) object;
  277.         return other.getN() == getN()                                                      &&
  278.                MathArrays.equalsIncludingNaN(other.getGeometricMean(), getGeometricMean()) &&
  279.                MathArrays.equalsIncludingNaN(other.getMax(),           getMax())           &&
  280.                MathArrays.equalsIncludingNaN(other.getMean(),          getMean())          &&
  281.                MathArrays.equalsIncludingNaN(other.getMin(),           getMin())           &&
  282.                MathArrays.equalsIncludingNaN(other.getSum(),           getSum())           &&
  283.                MathArrays.equalsIncludingNaN(other.getSumSq(),         getSumSq())         &&
  284.                MathArrays.equalsIncludingNaN(other.getSumLog(),        getSumLog())        &&
  285.                other.getCovariance().equals(getCovariance());
  286.     }

  287.     /**
  288.      * Returns hash code based on values of statistics
  289.      *
  290.      * @return hash code
  291.      */
  292.     @Override
  293.     public int hashCode() {
  294.         int result = 31 + MathUtils.hash(getN());
  295.         result = result * 31 + MathUtils.hash(getGeometricMean());
  296.         result = result * 31 + MathUtils.hash(getMax());
  297.         result = result * 31 + MathUtils.hash(getMean());
  298.         result = result * 31 + MathUtils.hash(getMin());
  299.         result = result * 31 + MathUtils.hash(getSum());
  300.         result = result * 31 + MathUtils.hash(getSumSq());
  301.         result = result * 31 + MathUtils.hash(getSumLog());
  302.         result = result * 31 + getCovariance().hashCode();
  303.         return result;
  304.     }

  305. }