package org.apache.mahout.classifier;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Map;
import org.apache.mahout.common.MahoutTestCase;
import org.junit.Test;

/**
 * Unit tests for {@link ConfusionMatrix}: cell storage, label bookkeeping, per-label accuracy,
 * export to a labelled matrix, and weighted precision/recall/F1 scores.
 */
public final class ConfusionMatrixTest extends MahoutTestCase {
    /** Counts for the two real labels: row = correct label, column = classified label. */
    private static final int[][] VALUES = {{2, 3}, {10, 20}};
    /** The two real labels, in the same order as the rows/columns of {@link #VALUES}. */
    private static final String[] LABELS = {"Label1", "Label2"};
    /** Per real label, how many instances were classified as {@link #DEFAULT_LABEL}. */
    private static final int[] OTHER = {3, 6};
    private static final String DEFAULT_LABEL = "other";

    @Test
    public void testBuild() {
        ConfusionMatrix confusionMatrix = fillConfusionMatrix(VALUES, LABELS, DEFAULT_LABEL);
        checkValues(confusionMatrix);
        checkAccuracy(confusionMatrix);
    }

    /**
     * Verifies the exported matrix: one column per label (including the default label),
     * every label bound as a row, and the diagonal ("correct") counts.
     */
    @Test
    public void testGetMatrix() {
        ConfusionMatrix confusionMatrix = fillConfusionMatrix(VALUES, LABELS, DEFAULT_LABEL);
        assertEquals(confusionMatrix.getLabels().size(), confusionMatrix.getMatrix().numCols());

        Map<String, Integer> rowLabelBindings = confusionMatrix.getMatrix().getRowLabelBindings();
        assertTrue(rowLabelBindings.keySet().contains(LABELS[0]));
        assertTrue(rowLabelBindings.keySet().contains(LABELS[1]));
        assertTrue(rowLabelBindings.keySet().contains(DEFAULT_LABEL));

        assertEquals(2, confusionMatrix.getCorrect(LABELS[0]));
        assertEquals(20, confusionMatrix.getCorrect(LABELS[1]));
        assertEquals(0, confusionMatrix.getCorrect(DEFAULT_LABEL));
    }

    /** Checks the weighted metrics against the values the test expects (named after scikit-learn). */
    @Test
    public void testPrecisionRecallAndF1ScoreAsScikitLearn() {
        ConfusionMatrix confusionMatrix = new ConfusionMatrix(Arrays.asList("0", "1", "2"), "DEFAULT");
        confusionMatrix.putCount("0", "0", 2);
        confusionMatrix.putCount("1", "0", 1);
        confusionMatrix.putCount("1", "2", 1);
        confusionMatrix.putCount("2", "1", 2);
        assertEquals(0.222d, confusionMatrix.getWeightedPrecision(), 0.001d);
        assertEquals(0.333d, confusionMatrix.getWeightedRecall(), 0.001d);
        assertEquals(0.266d, confusionMatrix.getWeightedF1score(), 0.001d);
    }

    /**
     * Asserts that the raw counts match the fixture built by {@link #fillConfusionMatrix}.
     * The matrix should be square with one row/column per label plus the default label.
     */
    private static void checkValues(ConfusionMatrix confusionMatrix) {
        int[][] counts = confusionMatrix.getConfusionMatrix();
        // Smoke test: rendering must not throw on a populated matrix.
        confusionMatrix.toString();

        assertEquals(counts.length, counts[0].length);
        assertEquals(3, counts.length);

        int defaultColumn = LABELS.length;
        for (int row = 0; row < VALUES.length; row++) {
            for (int col = 0; col < VALUES[row].length; col++) {
                assertEquals(VALUES[row][col], counts[row][col]);
            }
            assertEquals(OTHER[row], counts[row][defaultColumn]);
        }
        // Nothing was recorded with the default label as the correct label.
        assertTrue(Arrays.equals(new int[3], counts[defaultColumn]));

        assertEquals(3, confusionMatrix.getLabels().size());
        assertTrue(confusionMatrix.getLabels().contains(LABELS[0]));
        assertTrue(confusionMatrix.getLabels().contains(LABELS[1]));
        assertTrue(confusionMatrix.getLabels().contains(DEFAULT_LABEL));
    }

    /**
     * Asserts the per-label accuracy percentages for the fixture.
     * Label1: 2 / (2 + 3 + 3) = 25%. Label2: 20 / (10 + 20 + 6) ≈ 55.56%.
     * The default label's row is empty, so its accuracy is expected to be NaN.
     */
    private static void checkAccuracy(ConfusionMatrix confusionMatrix) {
        assertEquals(3, confusionMatrix.getLabels().size());
        assertEquals(25.0d, confusionMatrix.getAccuracy(LABELS[0]), 1.0E-6d);
        assertEquals(55.5555555d, confusionMatrix.getAccuracy(LABELS[1]), 1.0E-6d);
        assertTrue(Double.isNaN(confusionMatrix.getAccuracy(DEFAULT_LABEL)));
    }

    /**
     * Builds a confusion matrix over {@code labels[0]}, {@code labels[1]} and {@code defaultLabel}.
     *
     * @param values 2x2 counts; {@code values[i][j]} is the count of label i classified as label j
     * @param labels the two real labels (only the first two entries are used)
     * @param defaultLabel the catch-all label; filled from {@link #OTHER}
     * @return the populated matrix
     */
    private static ConfusionMatrix fillConfusionMatrix(int[][] values, String[] labels, String defaultLabel) {
        ArrayList<String> labelList = Lists.newArrayList(labels[0], labels[1]);
        ConfusionMatrix confusionMatrix = new ConfusionMatrix(labelList, defaultLabel);
        confusionMatrix.putCount(labels[0], labels[0], values[0][0]);
        confusionMatrix.putCount(labels[0], labels[1], values[0][1]);
        confusionMatrix.putCount(labels[1], labels[0], values[1][0]);
        confusionMatrix.putCount(labels[1], labels[1], values[1][1]);
        confusionMatrix.putCount(labels[0], defaultLabel, OTHER[0]);
        confusionMatrix.putCount(labels[1], defaultLabel, OTHER[1]);
        return confusionMatrix;
    }
}
