/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml;

import java.io.IOException;
import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionSuite;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.feature.StandardScaler;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.junit.Test;

/**
 * Smoke test that builds an ML {@link Pipeline} from Java, fits it on a generated
 * dataset and queries the transformed output through Spark SQL.
 *
 * <p>The Spark session and Java Spark context ({@code spark} and {@code jsc}) are
 * inherited from {@link SharedSparkSession}.
 */
public class JavaPipelineSuite extends SharedSparkSession {
    /** Input data shared by the tests; rebuilt by every {@link #setUp()} call. */
    private transient Dataset<Row> dataset;

    /**
     * Runs the parent set-up and then builds the {@link LabeledPoint} dataset used by the tests.
     *
     * @throws IOException if the parent set-up fails
     */
    @Override
    public void setUp() throws IOException {
        super.setUp();
        // NOTE(review): the generator arguments are presumably (offset, scale, nPoints, seed);
        // confirm against LogisticRegressionSuite. The trailing 2 is the slice count.
        JavaRDD<LabeledPoint> points = jsc.parallelize(
            LogisticRegressionSuite.generateLogisticInputAsList(1.0, 1.0, 100, 42), 2);
        dataset = spark.createDataFrame(points, LabeledPoint.class);
    }

    /**
     * Fits a two-stage pipeline (feature scaling followed by logistic regression),
     * registers the transformed dataset as the temp view {@code prediction} and
     * collects the label, probability and prediction columns from it.
     *
     * <p>There are no assertions: the test passes as long as the whole chain
     * executes without throwing.
     */
    @Test
    public void pipeline() {
        StandardScaler scaler = new StandardScaler()
            .setInputCol("features")
            .setOutputCol("scaledFeatures");
        // The regression consumes the scaler's output column rather than the raw features.
        LogisticRegression lr =
            (LogisticRegression) new LogisticRegression().setFeaturesCol("scaledFeatures");
        Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] {scaler, lr});

        PipelineModel model = pipeline.fit(dataset);
        model.transform(dataset).createOrReplaceTempView("prediction");

        Dataset<Row> predictions =
            spark.sql("SELECT label, probability, prediction FROM prediction");
        // Collecting forces evaluation of the query.
        predictions.collectAsList();
    }
}

