package org.apache.mahout.classifier.naivebayes;

import com.google.common.io.Closeables;
import java.io.File;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.hadoop.MathHelper;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:org/apache/mahout/classifier/naivebayes/NaiveBayesTest.class */
/**
 * End-to-end test that trains a naive Bayes model on a tiny hand-written data set and then
 * classifies an unseen instance, once with the standard and once with the complementary
 * classifier.
 *
 * <p>Each training instance is a six-dimensional indicator vector (two colors, two car types,
 * two origins) labelled either {@link #LABEL_STOLEN} or {@link #LABEL_NOT_STOLEN}.
 */
public class NaiveBayesTest extends MahoutTestCase {
    private Configuration conf;
    private File inputFile;
    private File outputDir;
    private File tempDir;

    static final Text LABEL_STOLEN = new Text("/stolen/");
    static final Text LABEL_NOT_STOLEN = new Text("/not_stolen/");

    // One indicator feature per attribute value; the index is the position in the feature vector.
    static final Vector.Element COLOR_RED = MathHelper.elem(0, 1.0d);
    static final Vector.Element COLOR_YELLOW = MathHelper.elem(1, 1.0d);
    static final Vector.Element TYPE_SPORTS = MathHelper.elem(2, 1.0d);
    static final Vector.Element TYPE_SUV = MathHelper.elem(3, 1.0d);
    static final Vector.Element ORIGIN_DOMESTIC = MathHelper.elem(4, 1.0d);
    static final Vector.Element ORIGIN_IMPORTED = MathHelper.elem(5, 1.0d);

    /**
     * Prepares the temp locations used by the tests and writes the ten labelled training
     * instances into a sequence file of {@code Text} keys and {@code VectorWritable} values.
     *
     * @throws Exception if the base-class set-up or writing the sequence file fails
     */
    @Override
    @Before
    public void setUp() throws Exception {
        super.setUp();
        this.conf = new Configuration();
        this.inputFile = getTestTempFile("trainingInstances.seq");
        this.outputDir = getTestTempDir("output");
        // NOTE(review): the directory is removed again right away, presumably so that the training
        // job can create its own output location -- confirm against TrainNaiveBayesJob.
        this.outputDir.delete();
        this.tempDir = getTestTempDir("tmp");

        SequenceFile.Writer writer = new SequenceFile.Writer(FileSystem.get(this.conf), this.conf,
                new Path(this.inputFile.getAbsolutePath()), Text.class, VectorWritable.class);
        try {
            writer.append(LABEL_STOLEN, trainingInstance(COLOR_RED, TYPE_SPORTS, ORIGIN_DOMESTIC));
            writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_RED, TYPE_SPORTS, ORIGIN_DOMESTIC));
            writer.append(LABEL_STOLEN, trainingInstance(COLOR_RED, TYPE_SPORTS, ORIGIN_DOMESTIC));
            writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SPORTS, ORIGIN_DOMESTIC));
            writer.append(LABEL_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SPORTS, ORIGIN_IMPORTED));
            writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SUV, ORIGIN_IMPORTED));
            writer.append(LABEL_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SUV, ORIGIN_IMPORTED));
            writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SUV, ORIGIN_DOMESTIC));
            writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_RED, TYPE_SUV, ORIGIN_IMPORTED));
            writer.append(LABEL_STOLEN, trainingInstance(COLOR_RED, TYPE_SPORTS, ORIGIN_IMPORTED));
        } finally {
            // Single close point for both the success and the failure path.
            Closeables.closeQuietly(writer);
        }
    }

    /**
     * Trains a standard model and checks the scores for a red, domestic SUV.
     *
     * @throws Exception if training or loading the model fails
     */
    @Test
    public void toyData() throws Exception {
        StandardNaiveBayesClassifier classifier = new StandardNaiveBayesClassifier(trainModel(false));
        assertEquals(2L, classifier.numCategories());
        assertSecondCategoryWins(
                classifier.classifyFull(trainingInstance(COLOR_RED, TYPE_SUV, ORIGIN_DOMESTIC).get()));
    }

    /**
     * Trains a complementary model and checks the scores for a red, domestic SUV.
     *
     * @throws Exception if training or loading the model fails
     */
    @Test
    public void toyDataComplementary() throws Exception {
        ComplementaryNaiveBayesClassifier classifier = new ComplementaryNaiveBayesClassifier(trainModel(true));
        assertEquals(2L, classifier.numCategories());
        assertSecondCategoryWins(
                classifier.classifyFull(trainingInstance(COLOR_RED, TYPE_SUV, ORIGIN_DOMESTIC).get()));
    }

    /**
     * Builds a training vector that has the given elements set and every other entry left at zero.
     *
     * @param elementArr index/value pairs to copy into the vector; indices must be below 6
     * @return a writable wrapping a dense vector of cardinality 6
     */
    static VectorWritable trainingInstance(Vector.Element... elementArr) {
        DenseVector denseVector = new DenseVector(6);
        for (Vector.Element element : elementArr) {
            denseVector.set(element.index(), element.get());
        }
        return new VectorWritable(denseVector);
    }

    /**
     * Runs the training job over the input written by {@link #setUp()} and loads the resulting
     * model from the output directory.
     *
     * @param complementary whether to pass {@code --trainComplementary} to the job
     * @return the model materialized from the job output
     * @throws Exception if the job or the model loading throws
     */
    private NaiveBayesModel trainModel(boolean complementary) throws Exception {
        String input = this.inputFile.getAbsolutePath();
        String output = this.outputDir.getAbsolutePath();
        String temp = this.tempDir.getAbsolutePath();

        String[] args;
        if (complementary) {
            args = new String[]{"--input", input, "--output", output, "-el", "--trainComplementary", "--tempDir", temp};
        } else {
            args = new String[]{"--input", input, "--output", output, "-el", "--tempDir", temp};
        }

        TrainNaiveBayesJob job = new TrainNaiveBayesJob();
        job.setConf(this.conf);
        // Fail here, with a clear message, rather than later while loading a model that was never written.
        int exitCode = job.run(args);
        assertEquals("TrainNaiveBayesJob exit code", 0L, (long) exitCode);

        return NaiveBayesModel.materialize(new Path(output), this.conf);
    }

    /**
     * Asserts that the category at index 1 scores strictly higher than the one at index 0.
     * NOTE(review): index 1 presumably corresponds to {@link #LABEL_NOT_STOLEN}; the label-to-index
     * order is decided by the training job and is not visible in this class -- confirm.
     *
     * @param scores per-category scores as returned by {@code classifyFull}
     */
    private static void assertSecondCategoryWins(Vector scores) {
        assertTrue(scores.get(0) < scores.get(1));
    }
}
