/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import org.apache.mahout.classifier.ConfusionMatrix;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.math.Matrix;
import org.junit.Test;

/**
 * Unit tests for {@link ConfusionMatrix}: building it from explicit counts, the exported
 * {@link Matrix} view and its row-label bindings, per-label accuracy, and the weighted
 * precision / recall / F1 scores.
 */
public final class ConfusionMatrixTest
extends MahoutTestCase {
    /** VALUES[i][j] is the count recorded for (LABELS[i], LABELS[j]) by fillConfusionMatrix. */
    private static final int[][] VALUES = new int[][]{{2, 3}, {10, 20}};
    /** The explicitly declared labels, in the order they are handed to the ConfusionMatrix. */
    private static final String[] LABELS = new String[]{"Label1", "Label2"};
    /** OTHER[i] is the count recorded for (LABELS[i], DEFAULT_LABEL) by fillConfusionMatrix. */
    private static final int[] OTHER = new int[]{3, 6};
    /** Default label given to the constructor; the assertions below expect it to become a third label. */
    private static final String DEFAULT_LABEL = "other";
    /** Tolerance for the percentage-valued accuracy assertions. */
    private static final double ACCURACY_DELTA = 1.0E-6;

    @Test
    public void testBuild() {
        ConfusionMatrix confusionMatrix = fillConfusionMatrix(VALUES, LABELS, DEFAULT_LABEL, OTHER);
        checkValues(confusionMatrix);
        checkAccuracy(confusionMatrix);
    }

    @Test
    public void testGetMatrix() {
        ConfusionMatrix confusionMatrix = fillConfusionMatrix(VALUES, LABELS, DEFAULT_LABEL, OTHER);
        Matrix m = confusionMatrix.getMatrix();
        Map<String, Integer> rowLabels = m.getRowLabelBindings();

        // The exported matrix has one column per label (default label included)...
        assertEquals(confusionMatrix.getLabels().size(), m.numCols());
        // ...and every label, including the default one, is bound to a row.
        assertTrue(rowLabels.containsKey(LABELS[0]));
        assertTrue(rowLabels.containsKey(LABELS[1]));
        assertTrue(rowLabels.containsKey(DEFAULT_LABEL));

        // getCorrect(label) is expected to equal the count stored for (label, label).
        assertEquals(VALUES[0][0], confusionMatrix.getCorrect(LABELS[0]));
        assertEquals(VALUES[1][1], confusionMatrix.getCorrect(LABELS[1]));
        // No count was ever stored for (DEFAULT_LABEL, DEFAULT_LABEL).
        assertEquals(0, confusionMatrix.getCorrect(DEFAULT_LABEL));
    }

    /**
     * Reproduces the scikit-learn {@code precision_recall_fscore_support(average='weighted')}
     * example with y_true = [0, 1, 2, 0, 1, 2] and y_pred = [0, 2, 1, 0, 0, 1]. That example
     * reports precision ~0.22, recall ~0.33 and F1 ~0.26.
     */
    @Test
    public void testPrecisionRecallAndF1ScoreAsScikitLearn() {
        List<String> labelList = Arrays.asList("0", "1", "2");
        ConfusionMatrix confusionMatrix = new ConfusionMatrix(labelList, "DEFAULT");
        // putCount(first, second, n): first is taken to be the true label and second the
        // predicted one, which is what makes these counts match the y_true/y_pred above.
        confusionMatrix.putCount("0", "0", 2);
        confusionMatrix.putCount("1", "0", 1);
        confusionMatrix.putCount("1", "2", 1);
        confusionMatrix.putCount("2", "1", 2);
        final double delta = 0.001;
        assertEquals(0.222, confusionMatrix.getWeightedPrecision(), delta);
        assertEquals(0.333, confusionMatrix.getWeightedRecall(), delta);
        assertEquals(0.266, confusionMatrix.getWeightedF1score(), delta);
    }

    /**
     * Asserts that the raw count array of a matrix built by fillConfusionMatrix from the
     * class constants is 3x3 and holds VALUES, OTHER, and an all-zero default-label row.
     */
    private static void checkValues(ConfusionMatrix cm) {
        int[][] counts = cm.getConfusionMatrix();
        // The result is discarded; this only verifies that toString() completes without throwing.
        cm.toString();

        // Square, with one row/column per declared label plus one for the default label.
        assertEquals(counts.length, counts[0].length);
        assertEquals(3, counts.length);

        int defaultIndex = LABELS.length;
        for (int row = 0; row < LABELS.length; row++) {
            for (int col = 0; col < LABELS.length; col++) {
                assertEquals(VALUES[row][col], counts[row][col]);
            }
            assertEquals(OTHER[row], counts[row][defaultIndex]);
        }
        // Nothing was recorded with the default label in the first position.
        assertTrue(Arrays.equals(new int[3], counts[defaultIndex]));

        assertEquals(3, cm.getLabels().size());
        assertTrue(cm.getLabels().contains(LABELS[0]));
        assertTrue(cm.getLabels().contains(LABELS[1]));
        assertTrue(cm.getLabels().contains(DEFAULT_LABEL));
    }

    /**
     * Asserts per-label accuracy for the matrix built from the class constants. The expected
     * values equal 100 * diagonal / row sum:
     * 100 * 2 / (2 + 3 + 3) = 25 and 100 * 20 / (10 + 20 + 6) = 55.55...
     * The default-label row is empty (0 / 0), so NaN is expected.
     */
    private static void checkAccuracy(ConfusionMatrix cm) {
        Collection<String> labels = cm.getLabels();
        assertEquals(3, labels.size());
        assertEquals(25.0, cm.getAccuracy(LABELS[0]), ACCURACY_DELTA);
        assertEquals(55.5555555, cm.getAccuracy(LABELS[1]), ACCURACY_DELTA);
        assertTrue(Double.isNaN(cm.getAccuracy(DEFAULT_LABEL)));
    }

    /**
     * Builds a ConfusionMatrix over {@code labels} plus {@code defaultLabel} and records the
     * given counts.
     *
     * @param values       values[i][j] is stored via putCount(labels[i], labels[j], ...);
     *                     must be at least labels.length x labels.length
     * @param labels       the explicitly declared labels
     * @param defaultLabel the default label passed to the ConfusionMatrix constructor
     * @param otherCounts  otherCounts[i] is stored via putCount(labels[i], defaultLabel, ...);
     *                     must have at least labels.length entries
     * @return the populated matrix
     */
    private static ConfusionMatrix fillConfusionMatrix(int[][] values, String[] labels,
                                                       String defaultLabel, int[] otherCounts) {
        List<String> labelList = Lists.newArrayList(labels);
        ConfusionMatrix confusionMatrix = new ConfusionMatrix(labelList, defaultLabel);
        for (int row = 0; row < labels.length; row++) {
            for (int col = 0; col < labels.length; col++) {
                confusionMatrix.putCount(labels[row], labels[col], values[row][col]);
            }
            confusionMatrix.putCount(labels[row], defaultLabel, otherCounts[row]);
        }
        return confusionMatrix;
    }
}

