/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.classification;

import java.util.Arrays;
import java.util.List;
import org.apache.spark.SharedSparkSession;
import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel;
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.junit.Assert;
import org.junit.Test;

/**
 * Java API test for {@link MultilayerPerceptronClassifier}.
 *
 * <p>Exercises the classifier from Java against the session provided by
 * {@link SharedSparkSession}.
 */
public class JavaMultilayerPerceptronClassifierSuite
extends SharedSparkSession {

    /**
     * Trains a 2-5-2 perceptron on the four XOR points and checks that every
     * training row is predicted with its own label.
     */
    @Test
    public void testMLPC() {
        // The four points of the XOR truth table: label = x0 XOR x1.
        List<LabeledPoint> data = Arrays.asList(
            new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
            new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
            new LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
            new LabeledPoint(0.0, Vectors.dense(1.0, 1.0)));
        Dataset<Row> dataFrame = this.spark.createDataFrame(data, LabeledPoint.class);

        // A hidden layer is required because XOR is not linearly separable.
        // The fixed seed keeps the weight initialisation, and therefore the test, deterministic.
        MultilayerPerceptronClassifier mlpc = new MultilayerPerceptronClassifier()
            .setLayers(new int[]{2, 5, 2})
            .setBlockSize(1)
            .setSeed(123L)
            .setMaxIter(100);
        MultilayerPerceptronClassificationModel model = mlpc.fit(dataFrame);

        Dataset<Row> result = model.transform(dataFrame);
        // Typed as List<Row> so the for-each below yields Row (a raw List would yield Object).
        List<Row> predictionAndLabels = result.select("prediction", "label").collectAsList();
        for (Row r : predictionAndLabels) {
            // Column 0 is "prediction", column 1 is "label", per the select() above.
            Assert.assertEquals("prediction mismatch for row " + r,
                (int) r.getDouble(1), (int) r.getDouble(0));
        }
    }
}

