package org.apache.spark.mllib.regression;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.spark.api.java.JavaSparkContext;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import scala.Tuple3;

/* loaded from: input_file:org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.class */
/**
 * Java API tests for {@link IsotonicRegression} and {@link IsotonicRegressionModel}.
 *
 * <p>The suite implements {@link Serializable} while keeping the Spark context {@code transient};
 * presumably this lets closures that capture the suite instance be serialized without the
 * context. NOTE(review): confirm whether any test still relies on this.
 */
public class JavaIsotonicRegressionSuite implements Serializable {

    /**
     * Training labels shared by both tests. The label at index {@code i} is paired with feature
     * {@code i + 1} (see {@link #generateIsotonicInput(double[])}).
     */
    private static final double[] TRAINING_LABELS =
        {1.0, 2.0, 3.0, 3.0, 1.0, 6.0, 7.0, 8.0, 11.0, 9.0, 10.0, 12.0};

    /** Local Spark context: created in {@link #setUp()}, stopped in {@link #tearDown()}. */
    private transient JavaSparkContext sc;

    /**
     * Builds isotonic regression input triples from a sequence of labels.
     *
     * @param labels labels in feature order
     * @return one {@code (label, feature, weight)} triple per label, in input order, where the
     *     feature is the 1-based position of the label and the weight is {@code 1.0}
     */
    private List<Tuple3<Double, Double, Double>> generateIsotonicInput(double[] labels) {
        List<Tuple3<Double, Double, Double>> input =
            new ArrayList<Tuple3<Double, Double, Double>>(labels.length);
        for (int i = 0; i < labels.length; i++) {
            input.add(new Tuple3<Double, Double, Double>(labels[i], (double) (i + 1), 1.0));
        }
        return input;
    }

    /**
     * Trains an isotonic regression model on {@code labels}.
     *
     * <p>The input is spread over 2 partitions and cached before training.
     *
     * @param labels labels in feature order
     * @return the trained model
     */
    private IsotonicRegressionModel runIsotonicRegression(double[] labels) {
        return new IsotonicRegression().run(sc.parallelize(generateIsotonicInput(labels), 2).cache());
    }

    /** Starts a local Spark context named after this suite. */
    @Before
    public void setUp() {
        sc = new JavaSparkContext("local", "JavaIsotonicRegressionSuite");
    }

    /** Stops the Spark context. The null guard avoids an NPE masking a failure in {@link #setUp()}. */
    @After
    public void tearDown() {
        if (sc != null) {
            sc.stop();
            sc = null;
        }
    }

    /** Checks the model's fitted predictions after training on {@link #TRAINING_LABELS}. */
    @Test
    public void testIsotonicRegressionJavaRDD() {
        IsotonicRegressionModel model = runIsotonicRegression(TRAINING_LABELS);

        // The non-monotone runs (3, 3, 1) and (11, 9, 10) are expected to be pooled to their
        // means, 7/3 and 10. NOTE(review): 12 inputs yield 10 predictions, presumably because the
        // model keeps only the two endpoints of each pooled run; confirm against the model docs.
        double[] expected = {1.0, 2.0, 7.0 / 3, 7.0 / 3, 6.0, 7.0, 8.0, 10.0, 10.0, 12.0};
        Assert.assertArrayEquals(expected, model.predictions(), 1.0e-14);
    }

    /**
     * Checks {@code predict} on a {@code JavaDoubleRDD}. The test points lie below the smallest
     * feature (0.0), on it (1.0), inside the pooled run 9..11 (9.5), on the largest feature (12.0)
     * and above it (13.0).
     */
    @Test
    public void testIsotonicRegressionPredictionsJavaRDD() {
        IsotonicRegressionModel model = runIsotonicRegression(TRAINING_LABELS);
        List<Double> predictions =
            model.predict(sc.parallelizeDoubles(Arrays.asList(0.0, 1.0, 9.5, 12.0, 13.0))).collect();

        double[] expected = {1.0, 1.0, 10.0, 12.0, 12.0};
        Assert.assertEquals("number of predictions", expected.length, predictions.size());
        for (int i = 0; i < expected.length; i++) {
            // Delta 0.0 keeps the original exact-equality semantics while reporting both values.
            Assert.assertEquals("prediction " + i, expected[i], predictions.get(i), 0.0);
        }
    }
}
