package org.apache.spark.ml.classification;

import java.io.Serializable;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import scala.collection.JavaConversions;

/* loaded from: input_file:org/apache/spark/ml/classification/JavaOneVsRestSuite.class */
/**
 * Java-API test for {@link OneVsRest}.
 *
 * <p>Each test runs against a fresh local {@link JavaSparkContext}. The training data is a small
 * synthetic multinomial dataset produced by
 * {@code org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput}.
 */
public class JavaOneVsRestSuite implements Serializable {
    private transient JavaSparkContext jsc;
    private transient SQLContext jsql;
    private transient DataFrame dataset;
    private transient JavaRDD<LabeledPoint> datasetRDD;

    /**
     * Creates a local Spark context and SQL context, then builds the training {@link DataFrame}
     * from generated {@link LabeledPoint}s split across 2 partitions.
     */
    @Before
    public void setUp() {
        this.jsc = new JavaSparkContext("local", "JavaLOneVsRestSuite");
        this.jsql = new SQLContext(this.jsc);

        // Generator inputs. NOTE(review): the argument order is assumed to be
        // (coefficients, xMean, xVariance, addIntercept, nPoints, seed); confirm against
        // LogisticRegressionSuite.generateMultinomialLogisticInput.
        double[] coefficients = {
            -0.57997d, 0.912083d, -0.371077d, -0.819866d, 2.688191d,
            -0.16624d, -0.84355d, -0.048509d, -0.301789d, 4.170682d
        };
        double[] xMean = {5.843d, 3.057d, 3.758d, 1.199d};
        double[] xVariance = {0.6856d, 0.1899d, 3.116d, 0.581d};
        boolean addIntercept = true;
        int nPoints = 3;
        int seed = 42;
        int numPartitions = 2;

        this.datasetRDD = this.jsc.parallelize(
            JavaConversions.seqAsJavaList(
                org.apache.spark.mllib.classification.LogisticRegressionSuite
                    .generateMultinomialLogisticInput(
                        coefficients, xMean, xVariance, addIntercept, nPoints, seed)),
            numPartitions);
        this.dataset = this.jsql.createDataFrame(this.datasetRDD, LabeledPoint.class);
    }

    /**
     * Stops the Spark context and drops every per-test reference so nothing from a stopped
     * context leaks into the next test. The null check prevents an NPE here from masking a
     * failure in {@link #setUp()} that happened before the context was assigned.
     */
    @After
    public void tearDown() {
        if (this.jsc != null) {
            this.jsc.stop();
        }
        this.jsc = null;
        this.jsql = null;
        this.dataset = null;
        this.datasetRDD = null;
    }

    /**
     * Checks that {@link OneVsRest} wrapping a default {@link LogisticRegression} uses the
     * "label"/"prediction" column names. Also checks that the fitted model keeps those names and
     * emits one prediction row per input row.
     */
    @Test
    public void oneVsRestDefaultParams() {
        OneVsRest oneVsRest = new OneVsRest();
        oneVsRest.setClassifier(new LogisticRegression());
        // JUnit convention: expected value first, actual second.
        Assert.assertEquals("label", oneVsRest.getLabelCol());
        Assert.assertEquals("prediction", oneVsRest.getPredictionCol());

        OneVsRestModel fit = oneVsRest.fit(this.dataset);
        // Collecting forces the transform to actually execute. The row count must be unchanged.
        long predictedRows = fit.transform(this.dataset)
            .select("label", "prediction")
            .collectAsList()
            .size();
        Assert.assertEquals(this.dataset.count(), predictedRows);

        Assert.assertEquals("label", fit.getLabelCol());
        Assert.assertEquals("prediction", fit.getPredictionCol());
    }
}
