package org.apache.spark.ml.tuning;

import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.evaluation.Evaluator;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.util.Instrumentation;
import org.apache.spark.mllib.util.MLUtils$;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.StructType;
import scala.Array$;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Serializable;
import scala.Some;
import scala.Tuple2;
import scala.concurrent.ExecutionContext;
import scala.math.Ordering$Double$;
import scala.reflect.ClassTag$;
import scala.runtime.AbstractFunction1;
import scala.runtime.BoxesRunTime;
import scala.runtime.ObjectRef;
import scala.runtime.ScalaRunTime$;

/* compiled from: CrossValidator.scala */
/* loaded from: input_file:org/apache/spark/ml/tuning/CrossValidator$$anonfun$fit$1.class */
/**
 * Compiler-generated Scala closure (decompiled) of type {@code Instrumentation => CrossValidatorModel}.
 *
 * <p>Judging by the synthetic class name, this is the function body that {@code CrossValidator.fit}
 * runs with an {@link Instrumentation} instance. It captures the enclosing {@link CrossValidator}
 * ({@code $outer}) and the dataset being fitted ({@code dataset$1}).
 *
 * <p>NOTE(review): this is decompiler output, not hand-written Java. Names such as
 * {@code Tuple2.mcDI.sp} are rendered from Scala's specialized class {@code Tuple2$mcDI$sp} and will
 * not compile as written. Treat this file as a reading aid for CrossValidator.scala rather than as
 * editable source.
 */
public final class CrossValidator$$anonfun$fit$1 extends AbstractFunction1<Instrumentation, CrossValidatorModel> implements Serializable {
    public static final long serialVersionUID = 0;
    // The CrossValidator whose params (estimator, evaluator, numFolds, seed, ...) are read below.
    private final /* synthetic */ CrossValidator $outer;
    // The dataset passed to fit(); it is split into folds and also used for the final refit.
    private final Dataset dataset$1;

    /**
     * Runs k-fold cross-validation over the outer validator's estimator param maps, then refits the
     * best one on the full dataset.
     *
     * @param instrumentation logger used for the pipeline stage, dataset, params and result messages
     * @return a {@code CrossValidatorModel} wrapping the refitted best model, the per-param-map
     *     metrics array ({@code dArr}) and the optional sub-models; the outer validator's param values
     *     are copied onto it
     * @throws MatchError if the selected (metric, index) tuple is null (a decompiled pattern-match guard)
     */
    public final CrossValidatorModel apply(Instrumentation instrumentation) {
        StructType schema = this.dataset$1.schema();
        // Validates the input schema against the validator; the result is discarded.
        // The second argument (true) is presumably the logging flag -- confirm against transformSchema.
        this.$outer.transformSchema(schema, true);
        SparkSession sparkSession = this.dataset$1.sparkSession();
        // Resolve the configured estimator, evaluator, candidate param maps and execution context.
        Estimator estimator = (Estimator) this.$outer.$(this.$outer.estimator());
        Evaluator evaluator = (Evaluator) this.$outer.$(this.$outer.evaluator());
        ParamMap[] paramMapArr = (ParamMap[]) this.$outer.$(this.$outer.estimatorParamMaps());
        ExecutionContext executionContext = this.$outer.getExecutionContext();
        // Record the stage, the dataset, and the numFolds/seed/parallelism params plus tuning params.
        instrumentation.logPipelineStage(this.$outer);
        instrumentation.logDataset(this.dataset$1);
        instrumentation.logParams(this.$outer, Predef$.MODULE$.wrapRefArray(new Param[]{this.$outer.numFolds(), this.$outer.seed(), this.$outer.parallelism()}));
        this.$outer.logTuningParams(instrumentation);
        boolean unboxToBoolean = BoxesRunTime.unboxToBoolean(this.$outer.$(this.$outer.collectSubModels()));
        // When collectSubModels is set, allocate Some(array of numFolds Model[] entries), each filled by
        // anonfun$3 (which receives paramMapArr); otherwise None. The Option is boxed in an ObjectRef
        // because it is captured by the per-fold closure (anonfun$4) below.
        ObjectRef create = ObjectRef.create(unboxToBoolean ? new Some(Array$.MODULE$.fill(BoxesRunTime.unboxToInt(this.$outer.$(this.$outer.numFolds())), new CrossValidator$$anonfun$fit$1$$anonfun$3(this, paramMapArr), ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Model.class)))) : None$.MODULE$);
        // Pipeline, innermost first:
        //   1. MLUtils.kFold(dataset rows, numFolds, seed) produces the fold splits.
        //   2. zipWithIndex pairs each split with its fold index.
        //   3. anonfun$4 maps each (split, foldIndex) to a double[] of metrics.
        //   4. The resulting double[][] is transposed.
        //   5. anonfun$7 reduces each transposed row to one double.
        // NOTE(review): anonfun$4/$7 are not visible here. Presumably $4 evaluates every param map on
        // one fold, and $7 averages across folds, so dArr[i] is the mean metric for paramMapArr[i].
        // Confirm against CrossValidator.scala.
        double[] dArr = (double[]) Predef$.MODULE$.refArrayOps(Predef$.MODULE$.refArrayOps((Object[]) Predef$.MODULE$.refArrayOps((Object[]) Predef$.MODULE$.refArrayOps(MLUtils$.MODULE$.kFold(this.dataset$1.toDF().rdd(), BoxesRunTime.unboxToInt(this.$outer.$(this.$outer.numFolds())), BoxesRunTime.unboxToLong(this.$outer.$(this.$outer.seed())), ClassTag$.MODULE$.apply(Row.class))).zipWithIndex(Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)))).map(new CrossValidator$$anonfun$fit$1$$anonfun$4(this, schema, sparkSession, estimator, evaluator, paramMapArr, executionContext, unboxToBoolean, create, instrumentation), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Double.TYPE))))).transpose(Predef$.MODULE$.$conforms())).map(new CrossValidator$$anonfun$fit$1$$anonfun$7(this), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Double()));
        instrumentation.logInfo(new CrossValidator$$anonfun$fit$1$$anonfun$apply$4(this, dArr));
        // Pick the best (metric, index) pair from dArr.zipWithIndex under Ordering.Double. Use the max
        // if the evaluator says larger is better, the min otherwise. The key extractors anonfun$8/$9
        // presumably select the metric component (_1) -- confirm.
        Tuple2 tuple2 = evaluator.isLargerBetter() ? (Tuple2) Predef$.MODULE$.refArrayOps((Object[]) Predef$.MODULE$.doubleArrayOps(dArr).zipWithIndex(Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)))).maxBy(new CrossValidator$$anonfun$fit$1$$anonfun$8(this), Ordering$Double$.MODULE$) : (Tuple2) Predef$.MODULE$.refArrayOps((Object[]) Predef$.MODULE$.doubleArrayOps(dArr).zipWithIndex(Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)))).minBy(new CrossValidator$$anonfun$fit$1$$anonfun$9(this), Ordering$Double$.MODULE$);
        // Compiler-emitted guard for destructuring the tuple in the Scala source.
        if (tuple2 == null) {
            throw new MatchError(tuple2);
        }
        // Unpack into (best metric: double, best index: int). "Tuple2.mcDI.sp" is the decompiler's
        // rendering of Scala's specialized Tuple2$mcDI$sp.
        Tuple2.mcDI.sp spVar = new Tuple2.mcDI.sp(tuple2._1$mcD$sp(), tuple2._2$mcI$sp());
        double _1$mcD$sp = spVar._1$mcD$sp();
        int _2$mcI$sp = spVar._2$mcI$sp();
        // Log the winning param map (by index) and its metric.
        instrumentation.logInfo(new CrossValidator$$anonfun$fit$1$$anonfun$apply$5(this, paramMapArr, _2$mcI$sp));
        instrumentation.logInfo(new CrossValidator$$anonfun$fit$1$$anonfun$apply$6(this, _1$mcD$sp));
        // Refit the estimator on the FULL dataset with the winning param map. Wrap the result together
        // with all metrics and the (optional) sub-models, set this validator as parent, and copy the
        // validator's param values onto the new model.
        return (CrossValidatorModel) this.$outer.copyValues(new CrossValidatorModel(this.$outer.uid(), (Model<?>) estimator.fit(this.dataset$1, paramMapArr[_2$mcI$sp]), dArr).setSubModels((Option<Model<?>[][]>) create.elem).setParent(this.$outer), this.$outer.copyValues$default$2());
    }

    /** Synthetic accessor generated by scalac that exposes the enclosing CrossValidator to nested closures. */
    public /* synthetic */ CrossValidator org$apache$spark$ml$tuning$CrossValidator$$anonfun$$$outer() {
        return this.$outer;
    }

    /**
     * Captures the enclosing validator and the dataset being fitted.
     *
     * @param crossValidator enclosing instance; must be non-null
     * @param dataset dataset passed to fit()
     * @throws NullPointerException if {@code crossValidator} is null ({@code throw null} is scalac's
     *     outer-reference null check)
     */
    public CrossValidator$$anonfun$fit$1(CrossValidator crossValidator, Dataset dataset) {
        if (crossValidator == null) {
            throw null;
        }
        this.$outer = crossValidator;
        this.dataset$1 = dataset;
    }
}
