/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * This is not the original file distributed by the Apache Software Foundation
 * It has been modified by the Hipparchus project
 */
package org.hipparchus.stat.descriptive;

import org.hipparchus.UnitTestUtils;
import org.hipparchus.exception.MathIllegalArgumentException;
import org.hipparchus.stat.descriptive.moment.GeometricMean;
import org.hipparchus.stat.descriptive.moment.Mean;
import org.hipparchus.stat.descriptive.moment.Variance;
import org.hipparchus.stat.descriptive.rank.Max;
import org.hipparchus.stat.descriptive.rank.Min;
import org.hipparchus.stat.descriptive.summary.Sum;
import org.hipparchus.stat.descriptive.summary.SumOfSquares;
import org.hipparchus.util.FastMath;
import org.hipparchus.util.Precision;
import org.junit.jupiter.api.Test;

import java.util.Arrays;
import java.util.Locale;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

/**
 * Test cases for the {@link DescriptiveStatistics} class.
 * <p>
 * Tests that obtain their instance through {@link #createDescriptiveStatistics()}
 * can be re-run against another {@link DescriptiveStatistics} implementation by
 * overriding that factory in a subclass.
 */
public class DescriptiveStatisticsTest {

    /** Reference data set; the expected statistics below are computed for it. */
    private final double[] testArray = new double[] { 1, 2, 2, 3 };

    private final double mean = 2;
    private final double sumSq = 18;
    private final double sum = 8;
    /** Bias-corrected variance of {@link #testArray}: 2 / (4 - 1). */
    private final double var = 0.666666666666666666667;
    /** Population variance of {@link #testArray}: 2 / 4. */
    private final double popVar = 0.5;
    private final double std = FastMath.sqrt(var);
    private final double n = 4;
    private final double min = 1;
    private final double max = 3;
    /** Absolute tolerance for floating-point comparisons (same value as the former 10E-15). */
    private final double tolerance = 1.0e-14;

    /**
     * Creates the instance under test; overridable so subclasses can exercise
     * another implementation.
     *
     * @return a fresh, empty statistics instance
     */
    protected DescriptiveStatistics createDescriptiveStatistics() {
        return new DescriptiveStatistics();
    }

    /** Values supplied as double, float, long and int must yield the reference statistics. */
    @Test
    void testStats() {
        DescriptiveStatistics u = createDescriptiveStatistics();
        assertEquals(0, u.getN(), tolerance, "total count");
        double one = 1;
        u.addValue(one);
        float twoF = 2;
        u.addValue(twoF);
        long twoL = 2;
        u.addValue(twoL);
        int three = 3;
        u.addValue(three);
        checkReferenceStatistics(u);
        u.clear();
        assertEquals(0, u.getN(), tolerance, "total count");
    }

    /** Feeding the values through a DoubleStream (using the instance as a DoubleConsumer) must match addValue. */
    @Test
    void testConsume() {
        DescriptiveStatistics u = createDescriptiveStatistics();
        assertEquals(0, u.getN(), tolerance, "total count");

        Arrays.stream(testArray)
              .forEach(u);

        checkReferenceStatistics(u);
        u.clear();
        assertEquals(0, u.getN(), tolerance, "total count");
    }

    /**
     * Asserts that {@code u} reports the reference statistics of {@link #testArray}.
     * Shared by the tests that load the same data through different entry points,
     * so their expectations cannot drift apart.
     *
     * @param u statistics instance that has been fed the values of {@link #testArray}
     */
    private void checkReferenceStatistics(DescriptiveStatistics u) {
        assertEquals(n, u.getN(), tolerance, "N");
        assertEquals(sum, u.getSum(), tolerance, "sum");
        assertEquals(sumSq, u.getSumOfSquares(), tolerance, "sumsq");
        assertEquals(var, u.getVariance(), tolerance, "var");
        assertEquals(popVar, u.getPopulationVariance(), tolerance, "population var");
        assertEquals(std, u.getStandardDeviation(), tolerance, "std");
        assertEquals(mean, u.getMean(), tolerance, "mean");
        assertEquals(min, u.getMin(), tolerance, "min");
        assertEquals(max, u.getMax(), tolerance, "max");
    }

    /** A copy must report the same statistics as its source. */
    @Test
    void testCopy() {
        DescriptiveStatistics stats = createDescriptiveStatistics();
        stats.addValue(1);
        stats.addValue(3);
        assertEquals(2, stats.getMean(), 1E-10);
        DescriptiveStatistics copy = stats.copy();
        assertEquals(2, copy.getMean(), 1E-10);
    }

    /**
     * Checks window-size handling:
     * <ul>
     *   <li>a window larger than the data keeps everything;</li>
     *   <li>an invalid size is rejected without altering the window;</li>
     *   <li>shrinking the window keeps only the most recent values.</li>
     * </ul>
     */
    @Test
    void testWindowSize() {
        DescriptiveStatistics stats = createDescriptiveStatistics();
        stats.setWindowSize(300);
        for (int i = 0; i < 100; ++i) {
            stats.addValue(i + 1);
        }
        int refSum = (100 * 101) / 2;
        assertEquals(refSum / 100.0, stats.getMean(), 1E-10);
        assertEquals(300, stats.getWindowSize());
        try {
            stats.setWindowSize(-3);
            fail("an exception should have been thrown");
        } catch (MathIllegalArgumentException expected) {
            // expected: negative window sizes are rejected
        }
        assertEquals(300, stats.getWindowSize());
        stats.setWindowSize(50);
        assertEquals(50, stats.getWindowSize());
        // only the last 50 values (51..100) remain after shrinking the window
        assertEquals(50, stats.getN());
        int refSum2 = refSum - (50 * 51) / 2;
        assertEquals(refSum2 / 50.0, stats.getMean(), 1E-10);
    }

    /** getValues must preserve insertion order, getSortedValues must sort ascending. */
    @Test
    void testGetValues() {
        DescriptiveStatistics stats = createDescriptiveStatistics();
        for (int i = 100; i > 0; --i) {
            stats.addValue(i);
        }
        int refSum = (100 * 101) / 2;
        assertEquals(refSum / 100.0, stats.getMean(), 1E-10);
        double[] v = stats.getValues();
        // guard against the element loops passing vacuously on a short/empty array
        assertEquals(100, v.length);
        for (int i = 0; i < v.length; ++i) {
            assertEquals(100.0 - i, v[i], 1.0e-10);
        }
        double[] s = stats.getSortedValues();
        assertEquals(100, s.length);
        for (int i = 0; i < s.length; ++i) {
            assertEquals(i + 1.0, s[i], 1.0e-10);
        }
        assertEquals(12.0, stats.getElement(88), 1.0e-10);
    }

    /** The quadratic mean must equal sqrt(sum(v^2) / n) to within one ulp. */
    @Test
    void testQuadraticMean() {
        final double[] values = { 1.2, 3.4, 5.6, 7.89 };
        final DescriptiveStatistics stats = new DescriptiveStatistics(values);

        final int len = values.length;
        double expected = 0;
        for (final double v : values) {
            expected += v * v / len;
        }
        expected = FastMath.sqrt(expected);

        assertEquals(expected, stats.getQuadraticMean(), FastMath.ulp(expected));
    }

    /** toString output is locale-sensitive, so it is checked under Locale.US. */
    @Test
    void testToString() {
        DescriptiveStatistics stats = createDescriptiveStatistics();
        stats.addValue(1);
        stats.addValue(2);
        stats.addValue(3);
        Locale defaultLocale = Locale.getDefault();
        Locale.setDefault(Locale.US);
        try {
            assertEquals("DescriptiveStatistics:\n" +
                         "n: 3\n" +
                         "min: 1.0\n" +
                         "max: 3.0\n" +
                         "mean: 2.0\n" +
                         "std dev: 1.0\n" +
                         "median: 2.0\n" +
                         "skewness: 0.0\n" +
                         "kurtosis: NaN\n",  stats.toString());
        } finally {
            // restore even when the assertion fails so later tests see the original locale
            Locale.setDefault(defaultLocale);
        }
    }

    /** The 50th percentile of {1, 2, 3} is the middle value. */
    @Test
    void testPercentile() {
        DescriptiveStatistics stats = createDescriptiveStatistics();

        stats.addValue(1);
        stats.addValue(2);
        stats.addValue(3);
        assertEquals(2, stats.getPercentile(50.0), 1E-10);
    }

    /** After overflowing a bounded window and clearing, a single added value must give N == 1. */
    @Test
    void test20090720() {
        DescriptiveStatistics descriptiveStatistics = new DescriptiveStatistics(100);
        for (int i = 0; i < 161; i++) {
            descriptiveStatistics.addValue(1.2);
        }
        descriptiveStatistics.clear();
        descriptiveStatistics.addValue(1.2);
        assertEquals(1, descriptiveStatistics.getN());
    }

    /** Replacing/removing the most recent value must work for small, exact, large and infinite windows. */
    @Test
    void testRemoval() {
        final DescriptiveStatistics dstat = createDescriptiveStatistics();

        checkRemoval(dstat, 1, 6.0, 0.0, Double.NaN);
        checkRemoval(dstat, 3, 5.0, 3.0, 4.5);
        checkRemoval(dstat, 6, 3.5, 2.5, 3.0);
        checkRemoval(dstat, 9, 3.5, 2.5, 3.0);
        checkRemoval(dstat, DescriptiveStatistics.INFINITE_WINDOW, 3.5, 2.5, 3.0);
    }

    /**
     * For each step of a sliding window, compares DescriptiveStatistics against two references:
     * a StreamingStatistics rebuilt from the window contents, and the individual
     * univariate statistics evaluated directly on those contents.
     */
    @Test
    void testSummaryConsistency() {
        final int windowSize = 5;
        final DescriptiveStatistics dstats = new DescriptiveStatistics(windowSize);
        final StreamingStatistics sstats = new StreamingStatistics();
        final double tol = 1E-12;
        for (int i = 0; i < 20; i++) {
            dstats.addValue(i);
            sstats.clear();
            double[] values = dstats.getValues();
            for (final double value : values) {
                sstats.addValue(value);
            }
            UnitTestUtils.customAssertEquals(dstats.getMean(), sstats.getMean(), tol);
            UnitTestUtils.customAssertEquals(new Mean().evaluate(values), dstats.getMean(), tol);
            UnitTestUtils.customAssertEquals(dstats.getMax(), sstats.getMax(), tol);
            UnitTestUtils.customAssertEquals(new Max().evaluate(values), dstats.getMax(), tol);
            UnitTestUtils.customAssertEquals(dstats.getGeometricMean(), sstats.getGeometricMean(), tol);
            UnitTestUtils.customAssertEquals(new GeometricMean().evaluate(values), dstats.getGeometricMean(), tol);
            UnitTestUtils.customAssertEquals(dstats.getMin(), sstats.getMin(), tol);
            UnitTestUtils.customAssertEquals(new Min().evaluate(values), dstats.getMin(), tol);
            UnitTestUtils.customAssertEquals(dstats.getStandardDeviation(), sstats.getStandardDeviation(), tol);
            UnitTestUtils.customAssertEquals(dstats.getVariance(), sstats.getVariance(), tol);
            UnitTestUtils.customAssertEquals(new Variance().evaluate(values), dstats.getVariance(), tol);
            UnitTestUtils.customAssertEquals(dstats.getSum(), sstats.getSum(), tol);
            UnitTestUtils.customAssertEquals(new Sum().evaluate(values), dstats.getSum(), tol);
            UnitTestUtils.customAssertEquals(dstats.getSumOfSquares(), sstats.getSumOfSquares(), tol);
            UnitTestUtils.customAssertEquals(new SumOfSquares().evaluate(values), dstats.getSumOfSquares(), tol);
            UnitTestUtils.customAssertEquals(dstats.getPopulationVariance(), sstats.getPopulationVariance(), tol);
            UnitTestUtils.customAssertEquals(new Variance(false).evaluate(values), dstats.getPopulationVariance(), tol);
        }
    }

    /**
     * Regression test for MATH-1129: with a NaN present in the data, the
     * 75th percentile must not be smaller than the 25th (interquartile range >= 0).
     */
    @Test
    void testMath1129(){
        final double[] data = new double[] {
            -0.012086732064244697,
            -0.24975668704012527,
            0.5706168483164684,
            -0.322111769955327,
            0.24166759508327315,
            Double.NaN,
            0.16698443218942854,
            -0.10427763937565114,
            -0.15595963093172435,
            -0.028075857595882995,
            -0.24137994506058857,
            0.47543170476574426,
            -0.07495595384947631,
            0.37445697625436497,
            -0.09944199541668033
        };

        final DescriptiveStatistics ds = new DescriptiveStatistics(data);

        final double t = ds.getPercentile(75);
        final double o = ds.getPercentile(25);

        final double iqr = t - o;
        assertTrue(iqr >= 0);
    }

    /**
     * Resets {@code dstat} to window size {@code wsize}, adds the values 1..6 and
     * checks the mean at three points.
     *
     * @param dstat instance to reuse (its window size is changed and it is cleared)
     * @param wsize window size to apply before adding values
     * @param mean1 expected mean after adding 1..6
     * @param mean2 expected mean after replacing the most recent value with 0
     * @param mean3 expected mean after then removing the most recent value
     *              (NaN when the window becomes empty)
     */
    public void checkRemoval(DescriptiveStatistics dstat, int wsize,
                             double mean1, double mean2, double mean3) {

        dstat.setWindowSize(wsize);
        dstat.clear();

        for (int i = 1 ; i <= 6 ; ++i) {
            dstat.addValue(i);
        }

        // equalsIncludingNaN so that an expected NaN mean (empty window) compares equal
        assertTrue(Precision.equalsIncludingNaN(mean1, dstat.getMean()));
        dstat.replaceMostRecentValue(0);
        assertTrue(Precision.equalsIncludingNaN(mean2, dstat.getMean()));
        dstat.removeMostRecentValue();
        assertTrue(Precision.equalsIncludingNaN(mean3, dstat.getMean()));
    }

}
