package org.apache.spark.mllib.regression;

import java.util.List;
import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.util.LinearDataGenerator;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:org/apache/spark/mllib/regression/JavaLinearRegressionSuite.class */
/**
 * Java-API tests for {@link LinearRegressionWithSGD} and {@link LinearRegressionModel}.
 *
 * <p>Every test trains on points produced by {@link LinearDataGenerator} with
 * {@link #TRAINING_SEED}; the accuracy tests score the trained model against a second set
 * generated with {@link #VALIDATION_SEED}.
 */
public class JavaLinearRegressionSuite extends SharedSparkSession {

    /** Number of points generated for both the training and the validation sets. */
    private static final int NUM_POINTS = 100;

    /** Seed used to generate the training set. */
    private static final int TRAINING_SEED = 42;

    /** Seed used to generate the validation set (distinct from {@link #TRAINING_SEED}). */
    private static final int VALIDATION_SEED = 17;

    /**
     * Last argument passed to {@code generateLinearInputAsList}.
     * NOTE(review): presumably the generator's noise scale ({@code eps}) — confirm.
     */
    private static final double GENERATOR_EPS = 0.1;

    /** Number of partitions for the parallelized training RDD. */
    private static final int NUM_PARTITIONS = 2;

    /** A prediction counts as accurate when it is within this distance of its label. */
    private static final double PREDICTION_TOLERANCE = 0.5;

    /**
     * Counts the points whose prediction lies within {@link #PREDICTION_TOLERANCE} of the label.
     *
     * @param validationData points to score
     * @param model trained model used to predict each point's label
     * @return number of points with {@code |prediction - label| <= PREDICTION_TOLERANCE}
     */
    int validatePrediction(List<LabeledPoint> validationData, LinearRegressionModel model) {
        int numAccurate = 0;
        for (LabeledPoint point : validationData) {
            double prediction = model.predict(point.features());
            if (Math.abs(prediction - point.label()) <= PREDICTION_TOLERANCE) {
                numAccurate++;
            }
        }
        return numAccurate;
    }

    /**
     * Asserts that strictly more than four fifths of {@link #NUM_POINTS} validation points are
     * predicted within {@link #PREDICTION_TOLERANCE}.
     *
     * @param validationData points to score
     * @param model trained model under test
     */
    private void assertMostlyAccurate(List<LabeledPoint> validationData, LinearRegressionModel model) {
        int numAccurate = validatePrediction(validationData, model);
        // Evaluated as (100 * 4.0) / 5.0, which is exactly 80.0.
        double threshold = NUM_POINTS * 4.0 / 5.0;
        Assert.assertTrue(
            "expected more than " + threshold + " accurate predictions but got " + numAccurate,
            numAccurate > threshold);
    }

    /**
     * Generates {@link #NUM_POINTS} training points, parallelizes them and caches the RDD.
     *
     * @param intercept first argument to {@code generateLinearInputAsList}
     *     (presumably the intercept — NOTE(review): confirm)
     * @param weights weights argument to {@code generateLinearInputAsList}
     * @return cached, typed RDD of training points
     */
    private JavaRDD<LabeledPoint> trainingRdd(double intercept, double[] weights) {
        List<LabeledPoint> points = LinearDataGenerator.generateLinearInputAsList(
            intercept, weights, NUM_POINTS, TRAINING_SEED, GENERATOR_EPS);
        return jsc.parallelize(points, NUM_PARTITIONS).cache();
    }

    @Test
    public void runLinearRegressionUsingConstructor() {
        double intercept = 3.0;
        double[] weights = {10.0, 10.0};
        JavaRDD<LabeledPoint> training = trainingRdd(intercept, weights);
        List<LabeledPoint> validationData = LinearDataGenerator.generateLinearInputAsList(
            intercept, weights, NUM_POINTS, VALIDATION_SEED, GENERATOR_EPS);

        LinearRegressionWithSGD algorithm = new LinearRegressionWithSGD();
        algorithm.setIntercept(true);
        // With a typed RDD, run() is not an unchecked call, so its generic return type is
        // preserved and no cast to LinearRegressionModel is needed.
        LinearRegressionModel model = algorithm.run(training.rdd());

        assertMostlyAccurate(validationData, model);
    }

    @Test
    public void runLinearRegressionUsingStaticMethods() {
        double intercept = 0.0;
        double[] weights = {10.0, 10.0};
        // Iteration count for train(); unrelated to NUM_POINTS even though both are 100.
        int numIterations = 100;
        JavaRDD<LabeledPoint> training = trainingRdd(intercept, weights);
        List<LabeledPoint> validationData = LinearDataGenerator.generateLinearInputAsList(
            intercept, weights, NUM_POINTS, VALIDATION_SEED, GENERATOR_EPS);

        LinearRegressionModel model = LinearRegressionWithSGD.train(training.rdd(), numIterations);

        assertMostlyAccurate(validationData, model);
    }

    @Test
    public void testPredictJavaRDD() {
        JavaRDD<LabeledPoint> training = trainingRdd(0.0, new double[]{10.0, 10.0});
        LinearRegressionModel model = new LinearRegressionWithSGD().run(training.rdd());

        JavaRDD<Vector> features = training.map(new Function<LabeledPoint, Vector>() {
            @Override
            public Vector call(LabeledPoint point) {
                return point.features();
            }
        });
        // Typed JavaRDD<Vector> selects the JavaRDD overload of predict explicitly.
        JavaRDD<Double> predictions = model.predict(features);

        // Fetching a prediction must not throw, and there must be one prediction per point.
        Assert.assertNotNull(predictions.first());
        Assert.assertEquals(NUM_POINTS, predictions.count());
    }
}
