package org.apache.spark.ml.regression;

import java.io.Serializable;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.param.ParamPair;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/**
 * Java-API tests for {@link LinearRegression}: default parameter values, and overriding
 * parameters both through setters and through extra {@link ParamPair}s passed to {@code fit}.
 *
 * <p>The class is {@link Serializable} while every Spark handle is {@code transient};
 * presumably so closures capturing the test instance can be serialized without the
 * context — NOTE(review): confirm this is still required.
 */
public class JavaLinearRegressionSuite implements Serializable {
    private transient JavaSparkContext jsc;
    private transient SQLContext jsql;
    private transient DataFrame dataset;
    private transient JavaRDD<LabeledPoint> datasetRDD;

    /**
     * Starts a local Spark context and builds the shared training data: points produced by
     * {@code LogisticRegressionSuite.generateLogisticInputAsList}, parallelized into 2
     * partitions, exposed as a DataFrame and registered as the temp table {@code dataset}.
     */
    @Before
    public void setUp() {
        this.jsc = new JavaSparkContext("local", "JavaLinearRegressionSuite");
        this.jsql = new SQLContext(this.jsc);
        this.datasetRDD = this.jsc.parallelize(
            LogisticRegressionSuite.generateLogisticInputAsList(1.0d, 1.0d, 100, 42), 2);
        this.dataset = this.jsql.createDataFrame(this.datasetRDD, LabeledPoint.class);
        this.dataset.registerTempTable("dataset");
    }

    /**
     * Stops the Spark context and drops every reference to Spark objects created in
     * {@link #setUp()}.
     */
    @After
    public void tearDown() {
        // setUp may have failed before the context was assigned; don't mask that failure
        // with a NullPointerException here.
        if (this.jsc != null) {
            this.jsc.stop();
        }
        this.jsc = null;
        this.jsql = null;
        this.dataset = null;
        this.datasetRDD = null;
    }

    /**
     * Checks the estimator's default label column and solver, and the model's default
     * feature/prediction columns. It also checks that a model fitted with defaults scores
     * every input row.
     */
    @Test
    public void linearRegressionDefaultParams() {
        LinearRegression lr = new LinearRegression();
        Assert.assertEquals("label", lr.getLabelCol());
        Assert.assertEquals("auto", lr.getSolver());

        LinearRegressionModel model = lr.fit(this.dataset);
        model.transform(this.dataset).registerTempTable("prediction");
        // Assert on the query result instead of discarding it: transform should emit one
        // prediction per input row.
        long scoredRows = this.jsql.sql("SELECT label, prediction FROM prediction").collect().length;
        Assert.assertEquals(this.dataset.count(), scoredRows);

        Assert.assertEquals("features", model.getFeaturesCol());
        Assert.assertEquals("prediction", model.getPredictionCol());
    }

    /**
     * Checks that setter values are recorded on the fitted model's parent. It also checks
     * that {@link ParamPair}s passed to {@code fit} override them for that fit only.
     */
    @Test
    public void linearRegressionWithSetters() {
        LinearRegression lr = new LinearRegression()
            .setMaxIter(10)
            .setRegParam(1.0d)
            .setSolver("l-bfgs");

        LinearRegression parent = lr.fit(this.dataset).parent();
        Assert.assertEquals(10, parent.getMaxIter());
        Assert.assertEquals(1.0d, parent.getRegParam(), 0.0d);

        // Typed wildcard array instead of the raw ParamPair[] the original built.
        LinearRegressionModel model = lr.fit(this.dataset, lr.maxIter().w(5),
            new ParamPair<?>[]{lr.regParam().w(0.1d), lr.predictionCol().w("thePred")});
        LinearRegression overriddenParent = model.parent();
        Assert.assertEquals(5, overriddenParent.getMaxIter());
        Assert.assertEquals(0.1d, overriddenParent.getRegParam(), 0.0d);
        Assert.assertEquals("thePred", model.getPredictionCol());
    }
}
