package org.apache.spark.ml.classification;

import java.util.Arrays;
import org.apache.spark.SharedSparkSession;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:org/apache/spark/ml/classification/JavaNaiveBayesSuite.class */
/**
 * Java-API tests for {@link NaiveBayes}: checks the default parameter values and runs a
 * fit/transform round trip on a small in-memory data set.
 */
public class JavaNaiveBayesSuite extends SharedSparkSession {

    /** Absolute tolerance used for every floating-point comparison in this suite. */
    private static final double EPSILON = 1.0E-5d;

    /**
     * Asserts that, in every row of {@code dataset}, the double in column 0 equals the double
     * in column 1 within {@link #EPSILON}.
     *
     * <p>The caller in this suite selects {@code ("prediction", "label")}, so column 1 (the
     * label) is passed to JUnit as the expected value and column 0 (the prediction) as the
     * actual value.
     *
     * @param dataset rows whose first two columns hold double values to compare
     */
    public void validatePrediction(Dataset<Row> dataset) {
        for (Row row : dataset.collectAsList()) {
            double actual = row.<Double>getAs(0);
            double expected = row.<Double>getAs(1);
            Assert.assertEquals(expected, actual, EPSILON);
        }
    }

    /** Verifies the column names, smoothing value and model type of a fresh estimator. */
    @Test
    public void naiveBayesDefaultParams() {
        NaiveBayes nb = new NaiveBayes();
        Assert.assertEquals("label", nb.getLabelCol());
        Assert.assertEquals("features", nb.getFeaturesCol());
        Assert.assertEquals("prediction", nb.getPredictionCol());
        Assert.assertEquals(1.0d, nb.getSmoothing(), EPSILON);
        Assert.assertEquals("multinomial", nb.getModelType());
    }

    /**
     * Fits a multinomial model (smoothing 0.5) on six rows covering three labels, where each
     * label has a non-zero value in exactly one feature position, then checks that the model's
     * predictions on that same data match the labels.
     */
    @Test
    public void testNaiveBayes() {
        StructType schema = new StructType(new StructField[]{
            new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
            new StructField("features", new VectorUDT(), false, Metadata.empty())
        });

        Dataset<Row> training = this.spark.createDataFrame(
            Arrays.asList(
                RowFactory.create(0.0d, Vectors.dense(1.0d, 0.0d, 0.0d)),
                RowFactory.create(0.0d, Vectors.dense(2.0d, 0.0d, 0.0d)),
                RowFactory.create(1.0d, Vectors.dense(0.0d, 1.0d, 0.0d)),
                RowFactory.create(1.0d, Vectors.dense(0.0d, 2.0d, 0.0d)),
                RowFactory.create(2.0d, Vectors.dense(0.0d, 0.0d, 1.0d)),
                RowFactory.create(2.0d, Vectors.dense(0.0d, 0.0d, 2.0d))),
            schema);

        NaiveBayes estimator = new NaiveBayes()
            .setSmoothing(0.5d)
            .setModelType("multinomial");
        NaiveBayesModel model = estimator.fit(training);

        // Column order matters: validatePrediction reads column 0 as the prediction and
        // column 1 as the label.
        Dataset<Row> predictionAndLabels = model.transform(training).select("prediction", "label");
        validatePrediction(predictionAndLabels);
    }
}
