/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.clustering;

import com.google.common.collect.Lists;
import java.util.Collection;
import org.apache.mahout.clustering.OnlineGaussianAccumulator;
import org.apache.mahout.clustering.RunningSumsGaussianAccumulator;
import org.apache.mahout.clustering.UncommonDistributions;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.function.DoubleFunction;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.function.SquareRootFunction;
import org.junit.Before;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Cross-checks {@link RunningSumsGaussianAccumulator} and {@link OnlineGaussianAccumulator}
 * against each other and against reference statistics computed directly from the sample data.
 *
 * <p>{@link #setUp()} regenerates a fresh set of 2-d samples before every test and computes
 * their count, mean and (n - 1)-normalized standard deviation, which the result tests use as
 * the expected values.
 */
public final class TestGaussianAccumulators
extends MahoutTestCase {
    private static final Logger log = LoggerFactory.getLogger(TestGaussianAccumulators.class);

    /** Number of elements in every generated sample vector (see generate2dSamples). */
    private static final int CARDINALITY = 2;
    /** Tolerance for quantities that should match (almost) exactly. */
    private static final double TIGHT_DELTA = 1.0E-6;
    /**
     * Looser tolerance for the running-sums std vs. the sample std; presumably needed because
     * the running-sums formulation is less numerically stable — NOTE(review): confirm.
     */
    private static final double RS_STD_DELTA = 1.0E-4;
    /** Tolerance when comparing the two accumulators' weighted stds. */
    private static final double WEIGHTED_STD_DELTA = 0.001;
    /** Tolerance when comparing the two accumulators' weighted variances. */
    private static final double WEIGHTED_VARIANCE_DELTA = 0.01;

    private Collection<VectorWritable> sampleData;
    private int sampleN;
    private Vector sampleMean;
    private Vector sampleStd;

    /**
     * Regenerates the sample data and recomputes the reference statistics before each test.
     */
    @Override
    @Before
    public void setUp() throws Exception {
        super.setUp();
        sampleData = Lists.newArrayList();
        generateSamples();
        computeSampleStatistics();
        log.info("Observing {} samples m=[{}, {}] sd=[{}, {}]",
                 new Object[]{sampleN, sampleMean.get(0), sampleMean.get(1), sampleStd.get(0), sampleStd.get(1)});
    }

    /**
     * Computes sampleN, sampleMean and sampleStd from sampleData using two passes: the first
     * accumulates the sum, and the second accumulates squared deviations from the mean.
     */
    private void computeSampleStatistics() {
        sampleN = 0;
        Vector sum = new DenseVector(CARDINALITY);
        for (VectorWritable v : sampleData) {
            sum.assign(v.get(), Functions.PLUS);
            sampleN++;
        }
        sampleMean = sum.divide(sampleN);

        // Declared as Vector (not DenseVector) because divide(...) returns a Vector.
        Vector sampleVar = new DenseVector(CARDINALITY);
        for (VectorWritable v : sampleData) {
            Vector delta = v.get().minus(sampleMean);
            sampleVar.assign(delta.times(delta), Functions.PLUS);
        }
        // Divide by n - 1 (Bessel's correction) to get the unbiased sample variance.
        sampleVar = sampleVar.divide(sampleN - 1);

        DoubleFunction sqrt = new SquareRootFunction();
        sampleStd = sampleVar.clone();
        sampleStd.assign(sqrt);
    }

    /**
     * Appends {@code num} 2-element samples to sampleData. The first coordinate of each sample
     * comes from {@code UncommonDistributions.rNorm(mx, sdx)} and the second from
     * {@code UncommonDistributions.rNorm(my, sdy)}.
     */
    private void generate2dSamples(int num, double mx, double my, double sdx, double sdy) {
        log.info("Generating {} samples m=[{}, {}] sd=[{}, {}]", new Object[]{num, mx, my, sdx, sdy});
        for (int i = 0; i < num; ++i) {
            double x = UncommonDistributions.rNorm(mx, sdx);
            double y = UncommonDistributions.rNorm(my, sdy);
            sampleData.add(new VectorWritable(new DenseVector(new double[]{x, y})));
        }
    }

    /** Generates the fixed sample population used by every test. */
    private void generateSamples() {
        generate2dSamples(50000, 1.0, 2.0, 3.0, 4.0);
    }

    /**
     * Asserts that both accumulators report the same N, mean and average std.
     *
     * @param rs the running-sums accumulator (already computed)
     * @param ol the online accumulator (already computed)
     */
    private static void assertSameBasicSummary(RunningSumsGaussianAccumulator rs, OnlineGaussianAccumulator ol) {
        assertEquals("N", rs.getN(), ol.getN(), TIGHT_DELTA);
        assertEquals("Means", rs.getMean(), ol.getMean());
        assertEquals("Avg Stds", rs.getAverageStd(), ol.getAverageStd(), TIGHT_DELTA);
    }

    /**
     * Feeds every sample to both accumulators with the same weight, computes them, and asserts
     * that they agree on N, mean, std and variance. Vectors are compared via zSum.
     *
     * @param weight the weight passed to every {@code observe} call
     */
    private void assertAccumulatorsAgreeWithWeight(double weight) {
        RunningSumsGaussianAccumulator accumulator0 = new RunningSumsGaussianAccumulator();
        OnlineGaussianAccumulator accumulator1 = new OnlineGaussianAccumulator();
        for (VectorWritable vw : sampleData) {
            accumulator0.observe(vw.get(), weight);
            accumulator1.observe(vw.get(), weight);
        }
        accumulator0.compute();
        accumulator1.compute();
        assertEquals("N", accumulator0.getN(), accumulator1.getN(), TIGHT_DELTA);
        assertEquals("Means", accumulator0.getMean().zSum(), accumulator1.getMean().zSum(), TIGHT_DELTA);
        assertEquals("Stds", accumulator0.getStd().zSum(), accumulator1.getStd().zSum(), WEIGHTED_STD_DELTA);
        assertEquals("Variance", accumulator0.getVariance().zSum(), accumulator1.getVariance().zSum(),
                     WEIGHTED_VARIANCE_DELTA);
    }

    /** Both accumulators must agree when compute() is called without any observations. */
    @Test
    public void testAccumulatorNoSamples() {
        RunningSumsGaussianAccumulator accumulator0 = new RunningSumsGaussianAccumulator();
        OnlineGaussianAccumulator accumulator1 = new OnlineGaussianAccumulator();
        accumulator0.compute();
        accumulator1.compute();
        assertSameBasicSummary(accumulator0, accumulator1);
    }

    /** Both accumulators must agree after a single all-zero observation with weight 1. */
    @Test
    public void testAccumulatorOneSample() {
        RunningSumsGaussianAccumulator accumulator0 = new RunningSumsGaussianAccumulator();
        OnlineGaussianAccumulator accumulator1 = new OnlineGaussianAccumulator();
        Vector sample = new DenseVector(CARDINALITY);
        accumulator0.observe(sample, 1.0);
        accumulator1.observe(sample, 1.0);
        accumulator0.compute();
        accumulator1.compute();
        assertSameBasicSummary(accumulator0, accumulator1);
    }

    /** The online accumulator must reproduce the directly computed sample N, mean and std. */
    @Test
    public void testOLAccumulatorResults() {
        OnlineGaussianAccumulator accumulator = new OnlineGaussianAccumulator();
        for (VectorWritable vw : sampleData) {
            accumulator.observe(vw.get(), 1.0);
        }
        accumulator.compute();
        log.info("OL Observed {} samples m=[{}, {}] sd=[{}, {}]",
                 new Object[]{(int) accumulator.getN(), accumulator.getMean().get(0), accumulator.getMean().get(1),
                              accumulator.getStd().get(0), accumulator.getStd().get(1)});
        assertEquals("OL N", sampleN, accumulator.getN(), TIGHT_DELTA);
        assertEquals("OL Mean", sampleMean.zSum(), accumulator.getMean().zSum(), TIGHT_DELTA);
        assertEquals("OL Std", sampleStd.zSum(), accumulator.getStd().zSum(), TIGHT_DELTA);
    }

    /** The running-sums accumulator must reproduce the directly computed sample N, mean and std. */
    @Test
    public void testRSAccumulatorResults() {
        RunningSumsGaussianAccumulator accumulator = new RunningSumsGaussianAccumulator();
        for (VectorWritable vw : sampleData) {
            accumulator.observe(vw.get(), 1.0);
        }
        accumulator.compute();
        log.info("RS Observed {} samples m=[{}, {}] sd=[{}, {}]",
                 new Object[]{(int) accumulator.getN(), accumulator.getMean().get(0), accumulator.getMean().get(1),
                              accumulator.getStd().get(0), accumulator.getStd().get(1)});
        // Messages previously said "OL ..." (copy-paste); they now name the accumulator under test.
        assertEquals("RS N", sampleN, accumulator.getN(), TIGHT_DELTA);
        assertEquals("RS Mean", sampleMean.zSum(), accumulator.getMean().zSum(), TIGHT_DELTA);
        assertEquals("RS Std", sampleStd.zSum(), accumulator.getStd().zSum(), RS_STD_DELTA);
    }

    /** Both accumulators must agree when every sample has weight 0.5. */
    @Test
    public void testAccumulatorWeightedResults() {
        assertAccumulatorsAgreeWithWeight(0.5);
    }

    /** Both accumulators must agree when every sample has weight 1.5. */
    @Test
    public void testAccumulatorWeightedResults2() {
        assertAccumulatorsAgreeWithWeight(1.5);
    }
}

