package org.apache.spark.ml.util;

import org.apache.spark.SparkFunSuite;
import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.evaluation.Evaluator;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.ml.param.ParamMap$;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.ByteType$;
import org.apache.spark.sql.types.DecimalType;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.FloatType$;
import org.apache.spark.sql.types.IntegerType$;
import org.apache.spark.sql.types.LongType$;
import org.apache.spark.sql.types.NumericType;
import org.apache.spark.sql.types.ShortType$;
import org.apache.spark.sql.types.StringType$;
import org.scalactic.Bool$;
import scala.Function2;
import scala.Function3;
import scala.Predef$;
import scala.StringContext;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.Iterable;
import scala.collection.Iterable$;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.immutable.List$;
import scala.collection.immutable.Map;
import scala.reflect.ManifestFactory$;
import scala.reflect.api.Mirror;
import scala.reflect.api.TypeCreator;
import scala.reflect.api.Types;
import scala.reflect.api.Universe;
import scala.reflect.runtime.package$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: MLTestingUtils.scala */
/* loaded from: input_file:org/apache/spark/ml/util/MLTestingUtils$.class */
/**
 * Decompiled form of the Scala object {@code org.apache.spark.ml.util.MLTestingUtils}: test helpers
 * that check model copying and check that estimators, ALS and evaluators accept a label/rating
 * column of each of the numeric types Short, Long, Integer, Float, Byte, Double and Decimal(10,0),
 * while rejecting a string-typed column with an {@link IllegalArgumentException}.
 *
 * <p>NOTE(review): this is decompiler output, not hand-written source. Several constructs mirror
 * the bytecode and are not expected to compile as ordinary Java:
 * <ul>
 *   <li>reassigning the {@code static final} {@code MODULE$} in the constructor;</li>
 *   <li>{@code new Tuple2.mcID.sp(...)}, the rendering of Scala's specialized {@code Tuple2$mcID$sp};</li>
 *   <li>references to synthetic {@code MLTestingUtils$$anonfun$N} classes whose bodies are not shown here.</li>
 * </ul>
 * Treat the Scala source as authoritative for what those anonymous functions do.
 */
public final class MLTestingUtils$ extends SparkFunSuite {
    /** Singleton instance of this Scala object; assigned by the private constructor. */
    public static final MLTestingUtils$ MODULE$ = null;

    // Create the singleton when the class is initialized (the constructor publishes it as MODULE$).
    static {
        new MLTestingUtils$();
    }

    /**
     * Asserts that copying {@code model} with an empty {@code ParamMap} preserves its parent
     * estimator. The copy's parent must have the same uid as the original's parent, and the two
     * parents must be equal by {@code equals}. Both comparisons are null-safe.
     *
     * @param model the model to copy; {@code parent().uid()} is called on it directly, so its
     *              parent must be non-null
     */
    public void checkCopy(Model<?> model) {
        Model copy = model.copy(ParamMap$.MODULE$.empty());
        String uid = copy.parent().uid();
        String uid2 = model.parent().uid();
        assertionsHelper().macroAssert(Bool$.MODULE$.binaryMacroBool(uid, "==", uid2, uid != null ? uid.equals(uid2) : uid2 == null), "");
        Estimator parent = copy.parent();
        Estimator parent2 = model.parent();
        assertionsHelper().macroAssert(Bool$.MODULE$.binaryMacroBool(parent, "==", parent2, parent != null ? parent.equals(parent2) : parent2 == null), "");
    }

    /**
     * Checks that estimator {@code t} works with a label column of every numeric type and rejects a
     * string label column.
     *
     * <p>A baseline model is fit on the {@code DoubleType} dataset. The remaining keys are selected
     * by {@code MLTestingUtils$$anonfun$2}, presumably "every type except DoubleType" (confirm in the
     * Scala source). Each of those is turned into a model by {@code MLTestingUtils$$anonfun$3},
     * presumably {@code t.fit} on that type's dataset. Each result is then handed to
     * {@code function2} together with the baseline model, via
     * {@code MLTestingUtils$$anonfun$checkNumericTypes$2}.
     *
     * <p>Finally, a one-row DataFrame with a string {@code label} column (plus {@code features} and
     * {@code censor}) is passed to {@code MLTestingUtils$$anonfun$4}, presumably {@code t.fit}. That
     * call must throw {@link IllegalArgumentException}, and the message must contain
     * "Column label must be of type NumericType but was actually of type StringType".
     *
     * @param t            the estimator under test
     * @param sparkSession session used to build the test DataFrames
     * @param z            {@code true} to use classification fixtures, {@code false} for regression
     *                     fixtures (Scala default {@code true}; see {@code checkNumericTypes$default$3})
     * @param function2    per-type model check; presumably called as (expected, actual), so verify
     *                     the argument order in the Scala source
     */
    public <M extends Model<M>, T extends Estimator<M>> void checkNumericTypes(T t, SparkSession sparkSession, boolean z, Function2<M, M, BoxedUnit> function2) {
        // Despite its decompiler-chosen name, this holds regression fixtures when z is false.
        // Both generators are called with their Scala default column names.
        Map<NumericType, Dataset<Row>> genClassifDFWithNumericLabelCol = z ? genClassifDFWithNumericLabelCol(sparkSession, genClassifDFWithNumericLabelCol$default$2(), genClassifDFWithNumericLabelCol$default$3()) : genRegressionDFWithNumericLabelCol(sparkSession, genRegressionDFWithNumericLabelCol$default$2(), genRegressionDFWithNumericLabelCol$default$3(), genRegressionDFWithNumericLabelCol$default$4());
        // Compare the models fit on the other numeric types against the DoubleType baseline.
        ((Iterable) ((TraversableLike) genClassifDFWithNumericLabelCol.keys().filter(new MLTestingUtils$$anonfun$2())).map(new MLTestingUtils$$anonfun$3(t, genClassifDFWithNumericLabelCol), Iterable$.MODULE$.canBuildFrom())).foreach(new MLTestingUtils$$anonfun$checkNumericTypes$2(function2, t.fit((Dataset) genClassifDFWithNumericLabelCol.apply(DoubleType$.MODULE$))));
        // A string-typed label column must be rejected with an IllegalArgumentException.
        String message = ((IllegalArgumentException) intercept(new MLTestingUtils$$anonfun$4(t, sparkSession.createDataFrame(Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Tuple3[]{new Tuple3("0", Vectors$.MODULE$.dense(0.0d, Predef$.MODULE$.wrapDoubleArray(new double[]{2.0d, 3.0d})), BoxesRunTime.boxToDouble(0.0d))})), package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(getClass().getClassLoader()), new TypeCreator() { // from class: org.apache.spark.ml.util.MLTestingUtils$$typecreator1$1
            public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                Universe universe = mirror.universe();
                return universe.internal().reificationSupport().TypeRef(universe.internal().reificationSupport().ThisType(mirror.staticPackage("scala").asModule().moduleClass()), mirror.staticClass("scala.Tuple3"), List$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Types.TypeApi[]{mirror.staticClass("java.lang.String").asType().toTypeConstructor(), mirror.staticClass("org.apache.spark.ml.linalg.Vector").asType().toTypeConstructor(), mirror.staticClass("scala.Double").asType().toTypeConstructor()})));
            }
        })).toDF(Predef$.MODULE$.wrapRefArray(new String[]{"label", "features", "censor"}))), ManifestFactory$.MODULE$.classType(IllegalArgumentException.class))).getMessage();
        assertionsHelper().macroAssert(Bool$.MODULE$.binaryMacroBool(message, "contains", "Column label must be of type NumericType but was actually of type StringType", message.contains("Column label must be of type NumericType but was actually of type StringType")), "");
    }

    /**
     * ALS variant of the numeric-type check. It varies the type of column {@code str} in a small
     * ratings DataFrame and checks ALS against each variant, then checks that a string-typed
     * {@code str} column is rejected.
     *
     * <p>A baseline model is fit on the dataset for {@code numericType}. The other keys are selected
     * by {@code MLTestingUtils$$anonfun$5}, presumably "every type except numericType". Those
     * datasets are mapped to models by {@code MLTestingUtils$$anonfun$6}, presumably
     * {@code als.fit}. The resulting collection is traversed twice:
     * <ul>
     *   <li>once with {@code function2} and the baseline model;</li>
     *   <li>once with {@code function3}, the baseline model and the generated datasets (presumably
     *       to compare transform outputs; confirm in the Scala source).</li>
     * </ul>
     *
     * <p>Finally, the {@code numericType} dataset is re-selected with {@code str} cast to
     * {@code StringType}, followed by the remaining columns (mapped by
     * {@code MLTestingUtils$$anonfun$7}, presumably to {@code Column}s). Passing it to
     * {@code MLTestingUtils$$anonfun$8}, presumably {@code als.fit}, must throw
     * {@link IllegalArgumentException}. The message must contain
     * "{@code <str>} must be of type NumericType but was actually of type StringType".
     *
     * @param als          the ALS estimator under test
     * @param sparkSession session used to build the test DataFrames
     * @param str          name of the column whose type is varied; presumably one of the generated
     *                     columns "user", "item" or "rating"
     * @param numericType  type whose dataset produces the baseline model; used as a map key, so it
     *                     should be one of the generated types
     * @param function2    model-vs-model check
     * @param function3    model-vs-model check that also receives a dataset
     */
    public void checkNumericTypesALS(ALS als, SparkSession sparkSession, String str, NumericType numericType, Function2<ALSModel, ALSModel, BoxedUnit> function2, Function3<ALSModel, ALSModel, Dataset<Row>, BoxedUnit> function3) {
        Map<NumericType, Dataset<Row>> genRatingsDFWithNumericCols = genRatingsDFWithNumericCols(sparkSession, str);
        ALSModel fit = als.fit((Dataset) genRatingsDFWithNumericCols.apply(numericType));
        // Models for the non-baseline types, built once and reused by both checks below.
        Iterable iterable = (Iterable) ((TraversableLike) genRatingsDFWithNumericCols.keys().filter(new MLTestingUtils$$anonfun$5(numericType))).map(new MLTestingUtils$$anonfun$6(als, genRatingsDFWithNumericCols), Iterable$.MODULE$.canBuildFrom());
        iterable.foreach(new MLTestingUtils$$anonfun$checkNumericTypesALS$1(function2, fit));
        iterable.foreach(new MLTestingUtils$$anonfun$checkNumericTypesALS$2(function3, genRatingsDFWithNumericCols, fit));
        // Rebuild the baseline dataset with column str cast to StringType (placed first) and expect rejection.
        Dataset dataset = (Dataset) genRatingsDFWithNumericCols.apply(numericType);
        String message = ((IllegalArgumentException) intercept(new MLTestingUtils$$anonfun$8(als, dataset.select((Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col(str).cast(StringType$.MODULE$)})).$plus$plus((Seq) ((TraversableLike) Predef$.MODULE$.refArrayOps(dataset.columns()).toSeq().diff(Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new String[]{str})))).map(new MLTestingUtils$$anonfun$7(), Seq$.MODULE$.canBuildFrom()), Seq$.MODULE$.canBuildFrom()))), ManifestFactory$.MODULE$.classType(IllegalArgumentException.class))).getMessage();
        // Expected message fragment: s"$str must be of type NumericType but was actually of type StringType".
        String s = new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"", " must be of type NumericType but was actually of type StringType"})).s(Predef$.MODULE$.genericWrapArray(new Object[]{str}));
        assertionsHelper().macroAssert(Bool$.MODULE$.binaryMacroBool(message, "contains", s, message.contains(s)), "");
    }

    /**
     * Checks that evaluator {@code t} handles a label column of every numeric type and rejects a
     * string label column.
     *
     * <p>A baseline metric is computed with {@code t.evaluate} on the {@code DoubleType} dataset.
     * The other keys are selected by {@code MLTestingUtils$$anonfun$9}, presumably "every type
     * except DoubleType", and mapped by {@code MLTestingUtils$$anonfun$10}, presumably to
     * {@code t.evaluate} results. Each is checked against the baseline by
     * {@code MLTestingUtils$$anonfun$checkNumericTypes$1}, presumably an equality assertion.
     *
     * <p>A one-row DataFrame with a string {@code label} column is then passed to
     * {@code MLTestingUtils$$anonfun$1}. That call must throw {@link IllegalArgumentException}, and
     * the message must contain
     * "Column label must be of type NumericType but was actually of type StringType".
     *
     * @param t            the evaluator under test; uses columns "label" and "prediction"
     * @param sparkSession session used to build the test DataFrames
     */
    public <T extends Evaluator> void checkNumericTypes(T t, SparkSession sparkSession) {
        Map<NumericType, Dataset<Row>> genEvaluatorDFWithNumericLabelCol = genEvaluatorDFWithNumericLabelCol(sparkSession, "label", "prediction");
        // Compare metrics for the other numeric types against the DoubleType baseline metric.
        ((Iterable) ((TraversableLike) genEvaluatorDFWithNumericLabelCol.keys().filter(new MLTestingUtils$$anonfun$9())).map(new MLTestingUtils$$anonfun$10(t, genEvaluatorDFWithNumericLabelCol), Iterable$.MODULE$.canBuildFrom())).foreach(new MLTestingUtils$$anonfun$checkNumericTypes$1(t.evaluate((Dataset) genEvaluatorDFWithNumericLabelCol.apply(DoubleType$.MODULE$))));
        // A string-typed label column must be rejected with an IllegalArgumentException.
        String message = ((IllegalArgumentException) intercept(new MLTestingUtils$$anonfun$1(t, sparkSession.createDataFrame(Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{new Tuple2("0", BoxesRunTime.boxToDouble(0.0d))})), package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(getClass().getClassLoader()), new TypeCreator() { // from class: org.apache.spark.ml.util.MLTestingUtils$$typecreator2$1
            public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                Universe universe = mirror.universe();
                return universe.internal().reificationSupport().TypeRef(universe.internal().reificationSupport().ThisType(mirror.staticPackage("scala").asModule().moduleClass()), mirror.staticClass("scala.Tuple2"), List$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Types.TypeApi[]{mirror.staticClass("java.lang.String").asType().toTypeConstructor(), mirror.staticClass("scala.Double").asType().toTypeConstructor()})));
            }
        })).toDF(Predef$.MODULE$.wrapRefArray(new String[]{"label", "prediction"}))), ManifestFactory$.MODULE$.classType(IllegalArgumentException.class))).getMessage();
        assertionsHelper().macroAssert(Bool$.MODULE$.binaryMacroBool(message, "contains", "Column label must be of type NumericType but was actually of type StringType", message.contains("Column label must be of type NumericType but was actually of type StringType")), "");
    }

    /** Scala default for the third parameter ({@code z}) of {@code checkNumericTypes}: classification fixtures. */
    public <M extends Model<M>, T extends Estimator<M>> boolean checkNumericTypes$default$3() {
        return true;
    }

    /**
     * Builds a five-row classification DataFrame and derives one dataset per numeric type.
     *
     * <p>The source DataFrame has integer labels 0, 1, 0, 1, 0 and three-element dense feature
     * vectors. Each type in {Short, Long, Integer, Float, Byte, Double, Decimal(10,0)} is mapped by
     * {@code MLTestingUtils$$anonfun$genClassifDFWithNumericLabelCol$1}, presumably casting the label
     * column to that type.
     *
     * @param sparkSession session used to create the DataFrame
     * @param str          label column name (Scala default "label")
     * @param str2         features column name (Scala default "features")
     * @return map from numeric type to its dataset
     */
    public Map<NumericType, Dataset<Row>> genClassifDFWithNumericLabelCol(SparkSession sparkSession, String str, String str2) {
        return ((TraversableOnce) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new NumericType[]{ShortType$.MODULE$, LongType$.MODULE$, IntegerType$.MODULE$, FloatType$.MODULE$, ByteType$.MODULE$, DoubleType$.MODULE$, new DecimalType(10, 0)})).map(new MLTestingUtils$$anonfun$genClassifDFWithNumericLabelCol$1(str, str2, sparkSession.createDataFrame(Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{new Tuple2(BoxesRunTime.boxToInteger(0), Vectors$.MODULE$.dense(0.0d, Predef$.MODULE$.wrapDoubleArray(new double[]{2.0d, 3.0d}))), new Tuple2(BoxesRunTime.boxToInteger(1), Vectors$.MODULE$.dense(0.0d, Predef$.MODULE$.wrapDoubleArray(new double[]{3.0d, 1.0d}))), new Tuple2(BoxesRunTime.boxToInteger(0), Vectors$.MODULE$.dense(0.0d, Predef$.MODULE$.wrapDoubleArray(new double[]{2.0d, 2.0d}))), new Tuple2(BoxesRunTime.boxToInteger(1), Vectors$.MODULE$.dense(0.0d, Predef$.MODULE$.wrapDoubleArray(new double[]{3.0d, 9.0d}))), new Tuple2(BoxesRunTime.boxToInteger(0), Vectors$.MODULE$.dense(0.0d, Predef$.MODULE$.wrapDoubleArray(new double[]{2.0d, 6.0d})))})), package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(getClass().getClassLoader()), new TypeCreator() { // from class: org.apache.spark.ml.util.MLTestingUtils$$typecreator3$1
            public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                Universe universe = mirror.universe();
                return universe.internal().reificationSupport().TypeRef(universe.internal().reificationSupport().ThisType(mirror.staticPackage("scala").asModule().moduleClass()), mirror.staticClass("scala.Tuple2"), List$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Types.TypeApi[]{mirror.staticClass("scala.Int").asType().toTypeConstructor(), mirror.staticClass("org.apache.spark.ml.linalg.Vector").asType().toTypeConstructor()})));
            }
        })).toDF(Predef$.MODULE$.wrapRefArray(new String[]{str, str2}))), Seq$.MODULE$.canBuildFrom())).toMap(Predef$.MODULE$.$conforms());
    }

    /** Scala default label column name for {@code genClassifDFWithNumericLabelCol}. */
    public String genClassifDFWithNumericLabelCol$default$2() {
        return "label";
    }

    /** Scala default features column name for {@code genClassifDFWithNumericLabelCol}. */
    public String genClassifDFWithNumericLabelCol$default$3() {
        return "features";
    }

    /**
     * Builds a five-row regression DataFrame and derives one dataset per numeric type.
     *
     * <p>The source DataFrame has integer labels 0..4 and one-element dense feature vectors holding
     * the same value. Each type in {Short, Long, Integer, Float, Byte, Double, Decimal(10,0)} is
     * mapped by {@code MLTestingUtils$$anonfun$genRegressionDFWithNumericLabelCol$1}, presumably
     * casting the label column.
     *
     * <p>Only two columns are named by {@code toDF}. Any censor column named {@code str3} must
     * therefore be added by that anonymous function; this is an assumption to confirm in the Scala
     * source.
     *
     * @param sparkSession session used to create the DataFrame
     * @param str          label column name (Scala default "label")
     * @param str2         features column name (Scala default "features")
     * @param str3         censor column name (Scala default "censor"), forwarded to the per-type mapper
     * @return map from numeric type to its dataset
     */
    public Map<NumericType, Dataset<Row>> genRegressionDFWithNumericLabelCol(SparkSession sparkSession, String str, String str2, String str3) {
        return ((TraversableOnce) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new NumericType[]{ShortType$.MODULE$, LongType$.MODULE$, IntegerType$.MODULE$, FloatType$.MODULE$, ByteType$.MODULE$, DoubleType$.MODULE$, new DecimalType(10, 0)})).map(new MLTestingUtils$$anonfun$genRegressionDFWithNumericLabelCol$1(str, str2, str3, sparkSession.createDataFrame(Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{new Tuple2(BoxesRunTime.boxToInteger(0), Vectors$.MODULE$.dense(0.0d, Predef$.MODULE$.wrapDoubleArray(new double[0]))), new Tuple2(BoxesRunTime.boxToInteger(1), Vectors$.MODULE$.dense(1.0d, Predef$.MODULE$.wrapDoubleArray(new double[0]))), new Tuple2(BoxesRunTime.boxToInteger(2), Vectors$.MODULE$.dense(2.0d, Predef$.MODULE$.wrapDoubleArray(new double[0]))), new Tuple2(BoxesRunTime.boxToInteger(3), Vectors$.MODULE$.dense(3.0d, Predef$.MODULE$.wrapDoubleArray(new double[0]))), new Tuple2(BoxesRunTime.boxToInteger(4), Vectors$.MODULE$.dense(4.0d, Predef$.MODULE$.wrapDoubleArray(new double[0])))})), package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(getClass().getClassLoader()), new TypeCreator() { // from class: org.apache.spark.ml.util.MLTestingUtils$$typecreator4$1
            public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                Universe universe = mirror.universe();
                return universe.internal().reificationSupport().TypeRef(universe.internal().reificationSupport().ThisType(mirror.staticPackage("scala").asModule().moduleClass()), mirror.staticClass("scala.Tuple2"), List$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Types.TypeApi[]{mirror.staticClass("scala.Int").asType().toTypeConstructor(), mirror.staticClass("org.apache.spark.ml.linalg.Vector").asType().toTypeConstructor()})));
            }
        })).toDF(Predef$.MODULE$.wrapRefArray(new String[]{str, str2}))), Seq$.MODULE$.canBuildFrom())).toMap(Predef$.MODULE$.$conforms());
    }

    /** Scala default label column name for {@code genRegressionDFWithNumericLabelCol}. */
    public String genRegressionDFWithNumericLabelCol$default$2() {
        return "label";
    }

    /** Scala default features column name for {@code genRegressionDFWithNumericLabelCol}. */
    public String genRegressionDFWithNumericLabelCol$default$3() {
        return "features";
    }

    /** Scala default censor column name for {@code genRegressionDFWithNumericLabelCol}. */
    public String genRegressionDFWithNumericLabelCol$default$4() {
        return "censor";
    }

    /**
     * Builds a five-row ratings DataFrame and derives one dataset per numeric type for column
     * {@code str}.
     *
     * <p>The DataFrame has columns "user" (0..4), "item" (10, 20, ..., 50) and "rating"
     * (1.0..5.0). The remaining column names (all columns except {@code str}) are mapped by
     * {@code MLTestingUtils$$anonfun$11}, presumably to {@code Column}s. Each type in
     * {Short, Long, Integer, Float, Byte, Double, Decimal(10,0)} is then mapped by
     * {@code MLTestingUtils$$anonfun$genRatingsDFWithNumericCols$1}, presumably casting {@code str}
     * to that type while keeping the other columns.
     *
     * @param sparkSession session used to create the DataFrame
     * @param str          name of the column whose type is varied
     * @return map from numeric type to its dataset
     */
    public Map<NumericType, Dataset<Row>> genRatingsDFWithNumericCols(SparkSession sparkSession, String str) {
        Dataset df = sparkSession.createDataFrame(Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Tuple3[]{new Tuple3(BoxesRunTime.boxToInteger(0), BoxesRunTime.boxToInteger(10), BoxesRunTime.boxToDouble(1.0d)), new Tuple3(BoxesRunTime.boxToInteger(1), BoxesRunTime.boxToInteger(20), BoxesRunTime.boxToDouble(2.0d)), new Tuple3(BoxesRunTime.boxToInteger(2), BoxesRunTime.boxToInteger(30), BoxesRunTime.boxToDouble(3.0d)), new Tuple3(BoxesRunTime.boxToInteger(3), BoxesRunTime.boxToInteger(40), BoxesRunTime.boxToDouble(4.0d)), new Tuple3(BoxesRunTime.boxToInteger(4), BoxesRunTime.boxToInteger(50), BoxesRunTime.boxToDouble(5.0d))})), package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(getClass().getClassLoader()), new TypeCreator() { // from class: org.apache.spark.ml.util.MLTestingUtils$$typecreator5$1
            public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                Universe universe = mirror.universe();
                return universe.internal().reificationSupport().TypeRef(universe.internal().reificationSupport().ThisType(mirror.staticPackage("scala").asModule().moduleClass()), mirror.staticClass("scala.Tuple3"), List$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Types.TypeApi[]{mirror.staticClass("scala.Int").asType().toTypeConstructor(), mirror.staticClass("scala.Int").asType().toTypeConstructor(), mirror.staticClass("scala.Double").asType().toTypeConstructor()})));
            }
        })).toDF(Predef$.MODULE$.wrapRefArray(new String[]{"user", "item", "rating"}));
        return ((TraversableOnce) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new NumericType[]{ShortType$.MODULE$, LongType$.MODULE$, IntegerType$.MODULE$, FloatType$.MODULE$, ByteType$.MODULE$, DoubleType$.MODULE$, new DecimalType(10, 0)})).map(new MLTestingUtils$$anonfun$genRatingsDFWithNumericCols$1(str, df, (Seq) ((TraversableLike) Predef$.MODULE$.refArrayOps(df.columns()).toSeq().diff(Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new String[]{str})))).map(new MLTestingUtils$$anonfun$11(), Seq$.MODULE$.canBuildFrom())), Seq$.MODULE$.canBuildFrom())).toMap(Predef$.MODULE$.$conforms());
    }

    /**
     * Builds a five-row evaluator DataFrame and derives one dataset per numeric type.
     *
     * <p>The DataFrame has integer column {@code str} (0..4) and double column {@code str2}
     * (0.0..4.0), so the two columns are numerically equal row by row. Each type in
     * {Short, Long, Integer, Float, Byte, Double, Decimal(10,0)} is mapped by
     * {@code MLTestingUtils$$anonfun$genEvaluatorDFWithNumericLabelCol$1}, presumably casting the
     * label column.
     *
     * @param sparkSession session used to create the DataFrame
     * @param str          label column name (Scala default "label")
     * @param str2         prediction column name (Scala default "prediction")
     * @return map from numeric type to its dataset
     */
    public Map<NumericType, Dataset<Row>> genEvaluatorDFWithNumericLabelCol(SparkSession sparkSession, String str, String str2) {
        return ((TraversableOnce) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new NumericType[]{ShortType$.MODULE$, LongType$.MODULE$, IntegerType$.MODULE$, FloatType$.MODULE$, ByteType$.MODULE$, DoubleType$.MODULE$, new DecimalType(10, 0)})).map(new MLTestingUtils$$anonfun$genEvaluatorDFWithNumericLabelCol$1(str, str2, sparkSession.createDataFrame(Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{new Tuple2.mcID.sp(0, 0.0d), new Tuple2.mcID.sp(1, 1.0d), new Tuple2.mcID.sp(2, 2.0d), new Tuple2.mcID.sp(3, 3.0d), new Tuple2.mcID.sp(4, 4.0d)})), package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(getClass().getClassLoader()), new TypeCreator() { // from class: org.apache.spark.ml.util.MLTestingUtils$$typecreator6$1
            public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                Universe universe = mirror.universe();
                return universe.internal().reificationSupport().TypeRef(universe.internal().reificationSupport().ThisType(mirror.staticPackage("scala").asModule().moduleClass()), mirror.staticClass("scala.Tuple2"), List$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Types.TypeApi[]{mirror.staticClass("scala.Int").asType().toTypeConstructor(), mirror.staticClass("scala.Double").asType().toTypeConstructor()})));
            }
        })).toDF(Predef$.MODULE$.wrapRefArray(new String[]{str, str2}))), Seq$.MODULE$.canBuildFrom())).toMap(Predef$.MODULE$.$conforms());
    }

    /** Scala default label column name for {@code genEvaluatorDFWithNumericLabelCol}. */
    public String genEvaluatorDFWithNumericLabelCol$default$2() {
        return "label";
    }

    /** Scala default prediction column name for {@code genEvaluatorDFWithNumericLabelCol}. */
    public String genEvaluatorDFWithNumericLabelCol$default$3() {
        return "prediction";
    }

    /** Java serialization hook ({@code readResolve}): resolves any deserialized instance to the singleton. */
    private Object readResolve() {
        return MODULE$;
    }

    /**
     * Private constructor that publishes this instance as {@code MODULE$}. In the bytecode this is
     * the Scala object initialization; as Java source, assigning the {@code static final} field here
     * is a decompiler artifact and is not legal.
     */
    private MLTestingUtils$() {
        MODULE$ = this;
    }
}
