package org.apache.spark.ml.classification;

import java.util.HashMap;
import java.util.Map;
import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.tree.impl.TreeTests;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.junit.Test;

/* Source recovered by decompiling org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.class. */
/**
 * Exercises {@link DecisionTreeClassifier} from Java: builder setters, {@code fit}, and the
 * resulting model's {@code transform}, {@code numNodes}, {@code depth} and
 * {@code toDebugString}.
 *
 * <p>No results are asserted; the test passes when every call completes without throwing.
 */
public class JavaDecisionTreeClassifierSuite extends SharedSparkSession {

    /**
     * Builds a small synthetic dataset, configures a depth-limited tree classifier, fits it,
     * and invokes the fitted model's accessors.
     */
    @Test
    public void runDT() {
        // Arguments forwarded to LogisticRegressionSuite.generateLogisticInputAsList.
        final double offset = 2.0;
        final double scale = -1.5;
        final int nPoints = 20;
        final int seed = 42;
        final int numPartitions = 2;
        final int numClasses = 2;

        JavaRDD<LabeledPoint> data = jsc.parallelize(
                LogisticRegressionSuite.generateLogisticInputAsList(offset, scale, nPoints, seed),
                numPartitions).cache();
        // Empty map: no feature is declared categorical (assumed meaning of the
        // categoricalFeatures argument of TreeTests.setMetadata -- confirm there).
        Map<Integer, Integer> categoricalFeatures = new HashMap<>();
        Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses);

        // Calls each setter in a single chain. setMaxDepth appears twice: the second call
        // re-sets the same value, so the chain also covers overwriting an already-set param.
        DecisionTreeClassifier dt = new DecisionTreeClassifier()
                .setMaxDepth(2)
                .setMaxBins(10)
                .setMinInstancesPerNode(5)
                .setMinInfoGain(0.0)
                .setMaxMemoryInMB(256)
                .setCacheNodeIds(false)
                .setCheckpointInterval(10)
                .setMaxDepth(2);

        // Passes every supported impurity through setImpurity; fit() below uses whichever
        // value was set last.
        for (String impurity : DecisionTreeClassifier.supportedImpurities()) {
            dt.setImpurity(impurity);
        }

        DecisionTreeClassificationModel model = dt.fit(dataFrame);

        // Return values are intentionally discarded: only successful execution is checked.
        model.transform(dataFrame);
        model.numNodes();
        model.depth();
        model.toDebugString();
    }
}
