package org.apache.spark.ml.regression;

import org.apache.spark.SparkException;
import org.apache.spark.ml.feature.Instance;
import org.apache.spark.ml.feature.OffsetInstance;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.optim.IterativelyReweightedLeastSquares;
import org.apache.spark.ml.optim.IterativelyReweightedLeastSquaresModel;
import org.apache.spark.ml.optim.WeightedLeastSquares;
import org.apache.spark.ml.optim.WeightedLeastSquares$;
import org.apache.spark.ml.optim.WeightedLeastSquaresModel;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.regression.GeneralizedLinearRegression;
import org.apache.spark.ml.util.Instrumentation;
import org.apache.spark.ml.util.OptionalInstrumentation$;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.DoubleType$;
import scala.Predef$;
import scala.Serializable;
import scala.Some;
import scala.StringContext;
import scala.collection.mutable.StringBuilder;
import scala.reflect.ClassTag$;
import scala.runtime.AbstractFunction1;
import scala.runtime.BoxesRunTime;

/* compiled from: GeneralizedLinearRegression.scala */
/* loaded from: input_file:org/apache/spark/ml/regression/GeneralizedLinearRegression$$anonfun$train$1.class */
/**
 * Function object that fits a {@link GeneralizedLinearRegressionModel} for the captured
 * {@link GeneralizedLinearRegression} estimator on the captured dataset.
 *
 * <p>{@link #apply(Instrumentation)} logs the stage, dataset, parameters and feature count,
 * validates the feature count, and then picks one of two solvers: {@link WeightedLeastSquares}
 * when the family equals {@code Gaussian} and the link equals {@code Identity}, otherwise
 * {@link IterativelyReweightedLeastSquares}.
 */
public final class GeneralizedLinearRegression$$anonfun$train$1 extends AbstractFunction1<Instrumentation, GeneralizedLinearRegressionModel> implements Serializable {
    public static final long serialVersionUID = 0;
    private final /* synthetic */ GeneralizedLinearRegression $outer;
    private final Dataset dataset$1;

    /**
     * Captures the enclosing estimator and the training dataset.
     *
     * @param generalizedLinearRegression enclosing estimator; a null value triggers {@code throw null}
     * @param dataset dataset the model is fitted on
     */
    public GeneralizedLinearRegression$$anonfun$train$1(GeneralizedLinearRegression generalizedLinearRegression, Dataset dataset) {
        if (generalizedLinearRegression == null) {
            throw null;
        }
        this.$outer = generalizedLinearRegression;
        this.dataset$1 = dataset;
    }

    /**
     * Fits the model and attaches a training summary to it.
     *
     * @param instrumentation sink for the pipeline-stage, dataset, parameter and feature-count logs;
     *                        also handed (wrapped) to the solvers
     * @return the fitted model with its summary set
     * @throws SparkException if the feature count exceeds {@code WeightedLeastSquares.MAX_NUM_FEATURES}
     */
    public final GeneralizedLinearRegressionModel apply(Instrumentation instrumentation) {
        GeneralizedLinearRegression.FamilyAndLink familyAndLink = GeneralizedLinearRegression$FamilyAndLink$.MODULE$.apply(this.$outer);

        // The feature count is taken from the vector in the first row of the features column.
        Column featuresOnly = functions$.MODULE$.col((String) this.$outer.$(this.$outer.featuresCol()));
        Row firstRow = (Row) this.dataset$1.select(Predef$.MODULE$.wrapRefArray(new Column[]{featuresOnly})).first();
        int numFeatures = ((Vector) firstRow.getAs(0)).size();

        instrumentation.logPipelineStage(this.$outer);
        instrumentation.logDataset(this.dataset$1);
        Param[] loggedParams = {
            this.$outer.labelCol(),
            this.$outer.featuresCol(),
            this.$outer.weightCol(),
            this.$outer.offsetCol(),
            this.$outer.predictionCol(),
            this.$outer.linkPredictionCol(),
            this.$outer.family(),
            this.$outer.solver(),
            this.$outer.fitIntercept(),
            this.$outer.link(),
            this.$outer.maxIter(),
            this.$outer.regParam(),
            this.$outer.tol()
        };
        instrumentation.logParams(this.$outer, Predef$.MODULE$.wrapRefArray(loggedParams));
        instrumentation.logNumFeatures(numFeatures);

        if (numFeatures > WeightedLeastSquares$.MODULE$.MAX_NUM_FEATURES()) {
            Object[] messageArgs = {
                BoxesRunTime.boxToInteger(WeightedLeastSquares$.MODULE$.MAX_NUM_FEATURES()),
                BoxesRunTime.boxToInteger(numFeatures)
            };
            String limitAndFound = new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{" <= ", ". Found ", " in the input dataset."}))
                    .s(Predef$.MODULE$.genericWrapArray(messageArgs));
            throw new SparkException(new StringBuilder()
                    .append("Currently, GeneralizedLinearRegression only supports number of features")
                    .append(limitAndFound)
                    .toString());
        }

        // With zero features, the requirement only holds when an intercept is fitted.
        boolean hasSomethingToFit = numFeatures > 0 || fitInterceptValue();
        Predef$.MODULE$.require(hasSomethingToFit, new GeneralizedLinearRegression$$anonfun$train$1$$anonfun$apply$1(this));

        // Missing weight column -> constant 1.0; missing offset column -> constant 0.0.
        Column weightColumn;
        if (this.$outer.hasWeightCol()) {
            weightColumn = functions$.MODULE$.col((String) this.$outer.$(this.$outer.weightCol()));
        } else {
            weightColumn = functions$.MODULE$.lit(BoxesRunTime.boxToDouble(1.0d));
        }
        Column offsetColumn;
        if (this.$outer.hasOffsetCol()) {
            offsetColumn = functions$.MODULE$.col((String) this.$outer.$(this.$outer.offsetCol())).cast(DoubleType$.MODULE$);
        } else {
            offsetColumn = functions$.MODULE$.lit(BoxesRunTime.boxToDouble(0.0d));
        }

        // The link is only inspected once the family has matched (short-circuit).
        if (nullSafeEquals(familyAndLink.family(), GeneralizedLinearRegression$Gaussian$.MODULE$)
                && nullSafeEquals(familyAndLink.link(), GeneralizedLinearRegression$Identity$.MODULE$)) {
            return fitWeightedLeastSquares(instrumentation, weightColumn, offsetColumn);
        }
        return fitIterativelyReweighted(instrumentation, familyAndLink, weightColumn, offsetColumn);
    }

    /** Single weighted-least-squares solve; the summary records an iteration count of 1. */
    private GeneralizedLinearRegressionModel fitWeightedLeastSquares(Instrumentation instrumentation, Column weightColumn, Column offsetColumn) {
        // The solver is constructed before the dataset projection, as in the original expression.
        WeightedLeastSquares solver = new WeightedLeastSquares(
                fitInterceptValue(),
                regParamValue(),
                0.0d,
                true,
                true,
                WeightedLeastSquares$.MODULE$.$lessinit$greater$default$6(),
                WeightedLeastSquares$.MODULE$.$lessinit$greater$default$7(),
                WeightedLeastSquares$.MODULE$.$lessinit$greater$default$8());
        RDD<Instance> instances = trainingColumns(weightColumn, offsetColumn)
                .rdd()
                .map(new GeneralizedLinearRegression$$anonfun$train$1$$anonfun$6(this), ClassTag$.MODULE$.apply(Instance.class));
        WeightedLeastSquaresModel solved = solver.fit(instances, OptionalInstrumentation$.MODULE$.create(instrumentation));

        GeneralizedLinearRegressionModel model = (GeneralizedLinearRegressionModel) this.$outer.copyValues(
                new GeneralizedLinearRegressionModel(this.$outer.uid(), solved.coefficients(), solved.intercept()).setParent(this.$outer),
                this.$outer.copyValues$default$2());
        GeneralizedLinearRegressionTrainingSummary trainingSummary = new GeneralizedLinearRegressionTrainingSummary(
                this.dataset$1, model, solved.diagInvAtWA().toArray(), 1, this.$outer.getSolver());
        return model.setSummary(new Some(trainingSummary));
    }

    /** Iteratively reweighted least squares, seeded by {@code familyAndLink.initialize}. */
    private GeneralizedLinearRegressionModel fitIterativelyReweighted(Instrumentation instrumentation, GeneralizedLinearRegression.FamilyAndLink familyAndLink, Column weightColumn, Column offsetColumn) {
        RDD<OffsetInstance> instances = trainingColumns(weightColumn, offsetColumn)
                .rdd()
                .map(new GeneralizedLinearRegression$$anonfun$train$1$$anonfun$7(this), ClassTag$.MODULE$.apply(OffsetInstance.class));
        WeightedLeastSquaresModel initialModel = familyAndLink.initialize(
                instances,
                fitInterceptValue(),
                regParamValue(),
                OptionalInstrumentation$.MODULE$.create(instrumentation));
        IterativelyReweightedLeastSquares optimizer = new IterativelyReweightedLeastSquares(
                initialModel,
                new GeneralizedLinearRegression$$anonfun$train$1$$anonfun$8(this, familyAndLink),
                fitInterceptValue(),
                regParamValue(),
                BoxesRunTime.unboxToInt(this.$outer.$(this.$outer.maxIter())),
                BoxesRunTime.unboxToDouble(this.$outer.$(this.$outer.tol())));
        IterativelyReweightedLeastSquaresModel solved = optimizer.fit(instances, OptionalInstrumentation$.MODULE$.create(instrumentation));

        GeneralizedLinearRegressionModel model = (GeneralizedLinearRegressionModel) this.$outer.copyValues(
                new GeneralizedLinearRegressionModel(this.$outer.uid(), solved.coefficients(), solved.intercept()).setParent(this.$outer),
                this.$outer.copyValues$default$2());
        GeneralizedLinearRegressionTrainingSummary trainingSummary = new GeneralizedLinearRegressionTrainingSummary(
                this.dataset$1, model, solved.diagInvAtWA().toArray(), solved.numIterations(), this.$outer.getSolver());
        return model.setSummary(new Some(trainingSummary));
    }

    /** Projects the dataset onto (label, weight, offset, features), in that order. */
    private Dataset trainingColumns(Column weightColumn, Column offsetColumn) {
        Column[] projection = {
            functions$.MODULE$.col((String) this.$outer.$(this.$outer.labelCol())),
            weightColumn,
            offsetColumn,
            functions$.MODULE$.col((String) this.$outer.$(this.$outer.featuresCol()))
        };
        return this.dataset$1.select(Predef$.MODULE$.wrapRefArray(projection));
    }

    /** Current value of the estimator's {@code fitIntercept} param. */
    private boolean fitInterceptValue() {
        return BoxesRunTime.unboxToBoolean(this.$outer.$(this.$outer.fitIntercept()));
    }

    /** Current value of the estimator's {@code regParam} param. */
    private double regParamValue() {
        return BoxesRunTime.unboxToDouble(this.$outer.$(this.$outer.regParam()));
    }

    /** Null-tolerant equality: a null left side only equals a null right side. */
    private static boolean nullSafeEquals(Object left, Object right) {
        return left == null ? right == null : left.equals(right);
    }
}
