/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.mllib.regression;

import java.util.List;
import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.regression.LassoModel;
import org.apache.spark.mllib.regression.LassoWithSGD;
import org.apache.spark.mllib.util.LinearDataGenerator;
import org.junit.Assert;
import org.junit.Test;

/**
 * Java API tests for {@link LassoWithSGD}.
 *
 * <p>Each test trains a {@link LassoModel} on synthetic linear data produced by
 * {@link LinearDataGenerator}. It then checks that more than 80% of the predictions on a
 * separately seeded validation set fall within {@code 0.5} of the true label.
 */
public class JavaLassoSuite extends SharedSparkSession {

    /** Number of points generated for both the training and the validation set. */
    private static final int N_POINTS = 10000;

    /** Intercept passed to the data generator. */
    private static final double INTERCEPT = 0.0;

    /** Seed for the training data; differs from the validation seed so the sets are independent. */
    private static final int TRAINING_SEED = 42;

    /** Seed for the validation data. */
    private static final int VALIDATION_SEED = 17;

    /** Value passed as the generator's {@code eps} argument (presumably the noise scale). */
    private static final double EPS = 0.1;

    /** Number of partitions used for the training RDD. */
    private static final int NUM_PARTITIONS = 2;

    /** Maximum absolute difference between prediction and label for a point to count as accurate. */
    private static final double TOLERANCE = 0.5;

    /**
     * Counts how many validation points the model predicts within {@link #TOLERANCE} of their label.
     *
     * @param validationData points to score
     * @param model trained model used for prediction
     * @return number of points with {@code |prediction - label| <= 0.5}; a NaN prediction is
     *     never counted as accurate
     */
    int validatePrediction(List<LabeledPoint> validationData, LassoModel model) {
        int numAccurate = 0;
        for (LabeledPoint point : validationData) {
            double prediction = model.predict(point.features());
            if (Math.abs(prediction - point.label()) <= TOLERANCE) {
                numAccurate++;
            }
        }
        return numAccurate;
    }

    /** Trains with 20 SGD iterations via the {@link LassoWithSGD} constructor. */
    @Test
    public void runLassoUsingConstructor() {
        JavaRDD<LabeledPoint> testRDD = trainingRDD();
        List<LabeledPoint> validationData = validationData();

        // Constructor argument order: stepSize, numIterations, regParam, miniBatchFraction.
        LassoWithSGD lassoSGDImpl = new LassoWithSGD(1.0, 20, 0.01, 1.0);
        LassoModel model = lassoSGDImpl.run(testRDD.rdd());

        assertMostlyAccurate(validatePrediction(validationData, model));
    }

    /**
     * Trains with 100 SGD iterations.
     *
     * <p>Despite its name, this test also goes through the {@link LassoWithSGD} constructor rather
     * than a static {@code train} helper. It differs from {@link #runLassoUsingConstructor()} only
     * in the iteration count. The name is kept so existing test selectors keep working.
     */
    @Test
    public void runLassoUsingStaticMethods() {
        JavaRDD<LabeledPoint> testRDD = trainingRDD();
        List<LabeledPoint> validationData = validationData();

        // Constructor argument order: stepSize, numIterations, regParam, miniBatchFraction.
        LassoModel model = new LassoWithSGD(1.0, 100, 0.01, 1.0).run(testRDD.rdd());

        assertMostlyAccurate(validatePrediction(validationData, model));
    }

    /** Returns a fresh copy of the true weight vector used to generate both data sets. */
    private static double[] weights() {
        return new double[]{-1.5, 0.01};
    }

    /** Builds the cached training RDD from data generated with {@link #TRAINING_SEED}. */
    private JavaRDD<LabeledPoint> trainingRDD() {
        List<LabeledPoint> points = LinearDataGenerator.generateLinearInputAsList(
                INTERCEPT, weights(), N_POINTS, TRAINING_SEED, EPS);
        return jsc.parallelize(points, NUM_PARTITIONS).cache();
    }

    /** Generates the validation set with {@link #VALIDATION_SEED}, independent of the training data. */
    private static List<LabeledPoint> validationData() {
        return LinearDataGenerator.generateLinearInputAsList(
                INTERCEPT, weights(), N_POINTS, VALIDATION_SEED, EPS);
    }

    /** Fails unless strictly more than 4/5 of the {@link #N_POINTS} validation points were accurate. */
    private static void assertMostlyAccurate(int numAccurate) {
        double threshold = N_POINTS * 4.0 / 5.0;
        Assert.assertTrue(
                "Expected more than " + threshold + " accurate predictions but got " + numAccurate,
                numAccurate > threshold);
    }
}

