package org.apache.mahout.math.stats;

import java.util.Random;
import org.apache.mahout.classifier.evaluation.Auc;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.stats.GlobalOnlineAuc;
import org.junit.Test;

/* loaded from: input_file:org/apache/mahout/math/stats/OnlineAucTest.class */
/**
 * Tests for the online AUC estimators {@link GlobalOnlineAuc} and {@link GroupedOnlineAuc}.
 *
 * <p>Every scenario scores negatives as {@code N(0, 1)} and positives as {@code N(1, 1)}, possibly
 * shifted by a per-group offset. For that pair of distributions the true AUC is
 * {@code Phi(1 / sqrt(2))}, roughly 0.7603, which is the reference value used by the assertions.
 */
public final class OnlineAucTest extends MahoutTestCase {

    /** Phi(1/sqrt(2)): the AUC when negative scores are N(0,1) and positive scores are N(1,1). */
    private static final double EXPECTED_AUC = 0.7603;

    /** Number of independent estimator runs summarized in {@link #testBinaryCase()}. */
    private static final int TRIALS = 100;

    /** Number of loop iterations (sample pairs, or sample quadruples when grouped) per run. */
    private static final int SAMPLES = 10000;

    /**
     * Runs {@link GlobalOnlineAuc} with each of the FAIR, FIFO and RANDOM replacement policies.
     * It also runs the exact {@link Auc} evaluator for reference.
     *
     * <p>This is repeated for {@link #TRIALS} runs. For each policy the test then checks that the
     * mean and quartiles 1 and 3 of the per-run AUC estimates are close to {@link #EXPECTED_AUC}.
     */
    @Test
    public void testBinaryCase() {
        Random random = RandomUtils.getRandom();
        GlobalOnlineAuc.ReplacementPolicy[] policies = {
            GlobalOnlineAuc.ReplacementPolicy.FAIR,
            GlobalOnlineAuc.ReplacementPolicy.FIFO,
            GlobalOnlineAuc.ReplacementPolicy.RANDOM
        };

        // stats[p] summarizes policies[p]; the extra last slot summarizes the exact Auc reference.
        OnlineSummarizer[] stats = new OnlineSummarizer[policies.length + 1];
        for (int p = 0; p < stats.length; p++) {
            stats[p] = new OnlineSummarizer();
        }

        for (int trial = 0; trial < TRIALS; trial++) {
            GlobalOnlineAuc[] estimators = new GlobalOnlineAuc[policies.length];
            for (int p = 0; p < policies.length; p++) {
                estimators[p] = new GlobalOnlineAuc();
                estimators[p].setPolicy(policies[p]);
            }
            Auc exact = new Auc();

            for (int i = 0; i < SAMPLES; i++) {
                // Feed each negative to every evaluator before drawing the matching positive.
                // Every estimator and the exact evaluator then see the identical sample sequence.
                double negative = random.nextGaussian();
                for (GlobalOnlineAuc estimator : estimators) {
                    estimator.addSample(0, negative);
                }
                exact.add(0, negative);

                double positive = random.nextGaussian() + 1.0;
                for (GlobalOnlineAuc estimator : estimators) {
                    estimator.addSample(1, positive);
                }
                exact.add(1, positive);
            }

            for (int p = 0; p < policies.length; p++) {
                stats[p].add(estimators[p].auc());
            }
            stats[policies.length].add(exact.auc());
        }

        for (int p = 0; p < stats.length; p++) {
            String label = p < policies.length ? String.valueOf(policies[p]) : "exact";
            OnlineSummarizer summary = stats[p];
            // Quartiles are labelled by index. NOTE(review): the assertions below treat
            // quartiles 1 and 3 as a symmetric pair, so q3 is presumably the 75th percentile
            // rather than the maximum. Confirm against OnlineSummarizer.getQuartile.
            System.out.printf("%s,%.4f (q0=%.4f, q1=%.4f, q2=%.4f, q3=%.4f)%n",
                label,
                summary.getMean(),
                summary.getQuartile(0),
                summary.getQuartile(1),
                summary.getQuartile(2),
                summary.getQuartile(3));
        }

        // FAIR is held to a looser tolerance than FIFO and RANDOM.
        assertSummaryNear(stats[0], 0.03, 0.03);
        assertSummaryNear(stats[1], 0.001, 0.006);
        // RANDOM gets the same checks as the other policies, including quartile 3.
        assertSummaryNear(stats[2], 0.001, 0.006);
    }

    /**
     * Asserts that the mean and quartiles 1 and 3 of {@code summary} are close to
     * {@link #EXPECTED_AUC}.
     *
     * @param summary           per-run AUC estimates for one replacement policy
     * @param meanTolerance     allowed absolute deviation of the mean
     * @param quartileTolerance allowed absolute deviation of quartiles 1 and 3
     */
    private void assertSummaryNear(OnlineSummarizer summary, double meanTolerance, double quartileTolerance) {
        assertEquals(EXPECTED_AUC, summary.getMean(), meanTolerance);
        assertEquals(EXPECTED_AUC, summary.getQuartile(1), quartileTolerance);
        assertEquals(EXPECTED_AUC, summary.getQuartile(3), quartileTolerance);
    }

    /**
     * Checks that {@link GroupedOnlineAuc} rejects the ungrouped {@code addSample(int, double)}
     * overload by throwing {@link UnsupportedOperationException}.
     */
    @Test(expected = UnsupportedOperationException.class)
    public void mustNotOmitGroup() {
        new GroupedOnlineAuc().addSample(0, 3.14d);
    }

    /**
     * Contrasts per-group AUC ({@link GroupedOnlineAuc}) with the AUC reported by
     * {@link GlobalOnlineAuc} on two groups.
     *
     * <p>Group "b" is group "a" shifted up by 10, so each group on its own has AUC ~0.7603. The
     * global expectation of 0.63015 corresponds to pooling all samples regardless of group key.
     * In the pooled data, cross-group pairs are almost always ranked "b above a". So b-positives
     * beat a-negatives (~1) and a-positives lose to b-negatives (~0). That gives
     * (0.7603 + 0.7603 + 1 + 0) / 4 ~= 0.63015.
     */
    @Test
    public void groupedAuc() {
        Random random = RandomUtils.getRandom();
        GroupedOnlineAuc grouped = new GroupedOnlineAuc();
        GlobalOnlineAuc global = new GlobalOnlineAuc();
        for (int i = 0; i < SAMPLES; i++) {
            grouped.addSample(0, "a", random.nextGaussian());
            grouped.addSample(1, "a", random.nextGaussian() + 1.0);
            grouped.addSample(0, "b", random.nextGaussian() + 10.0);
            grouped.addSample(1, "b", random.nextGaussian() + 11.0);

            global.addSample(0, "a", random.nextGaussian());
            global.addSample(1, "a", random.nextGaussian() + 1.0);
            global.addSample(0, "b", random.nextGaussian() + 10.0);
            global.addSample(1, "b", random.nextGaussian() + 11.0);
        }
        assertEquals(EXPECTED_AUC, grouped.auc(), 0.01);
        assertEquals(0.63015, global.auc(), 0.02);
    }
}
