/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.classification;

import java.util.HashMap;
import java.util.Map;
import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.classification.GBTClassificationModel;
import org.apache.spark.ml.classification.GBTClassifier;
import org.apache.spark.ml.classification.LogisticRegressionSuite;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.tree.impl.TreeTests;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.junit.Test;

/**
 * Java API smoke test for {@link GBTClassifier}.
 *
 * <p>Builds a small generated dataset, configures the estimator through its fluent setters,
 * fits a model and calls the fitted model's public accessors from Java.
 */
public class JavaGBTClassifierSuite
extends SharedSparkSession {
    /**
     * Fits a GBT classifier on {@code nPoints} generated points and exercises the fitted model.
     *
     * <p>No assertions are made: the test passes if every call completes without throwing.
     */
    @Test
    public void runDT() {
        int nPoints = 20;
        double A = 2.0;
        double B = -1.5;

        // Declared with its element type (not as a raw JavaRDD) so it can be passed to
        // TreeTests.setMetadata without an unchecked cast.
        JavaRDD<LabeledPoint> data = this.jsc.parallelize(
                LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();

        // No categorical features are declared for this dataset.
        Map<Integer, Integer> categoricalFeatures = new HashMap<>();
        // NOTE(review): the trailing 2 is presumably the number of label classes — confirm
        // against TreeTests.setMetadata.
        Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);

        // NOTE(review): setMaxDepth(2) is called twice; presumably intentional, to check that
        // repeating a setter in the fluent chain is accepted — confirm before removing.
        GBTClassifier gbt = new GBTClassifier()
                .setMaxDepth(2)
                .setMaxBins(10)
                .setMinInstancesPerNode(5)
                .setMinInfoGain(0.0)
                .setMaxMemoryInMB(256)
                .setCacheNodeIds(false)
                .setCheckpointInterval(10)
                .setSubsamplingRate(1.0)
                .setSeed(1234L)
                .setMaxIter(3)
                .setStepSize(0.1)
                .setMaxDepth(2);

        // Check that the setter accepts every supported loss type. Each call overwrites the
        // previous one, so only the last type in the list is in effect for fit() below.
        for (String lossType : GBTClassifier.supportedLossTypes()) {
            gbt.setLossType(lossType);
        }

        GBTClassificationModel model = gbt.fit(dataFrame);

        // Results are intentionally discarded; these calls only check that the model's
        // Java-facing API can be invoked without error.
        model.transform(dataFrame);
        model.totalNumNodes();
        model.toDebugString();
        model.trees();
        model.treeWeights();
    }
}

