package org.apache.spark.ml.evaluation;

import org.apache.spark.annotation.Experimental;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.ParamPair;
import org.apache.spark.ml.param.ParamValidators$;
import org.apache.spark.ml.param.shared.HasLabelCol;
import org.apache.spark.ml.param.shared.HasPredictionCol;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.ml.util.SchemaUtils$;
import org.apache.spark.mllib.evaluation.RegressionMetrics;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.StructType;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;

/* compiled from: RegressionEvaluator.scala */
@ScalaSignature(bytes = "\u0006\u0001=4A!\u0001\u0002\u0003\u001b\t\u0019\"+Z4sKN\u001c\u0018n\u001c8Fm\u0006dW/\u0019;pe*\u00111\u0001B\u0001\u000bKZ\fG.^1uS>t'BA\u0003\u0007\u0003\tiGN\u0003\u0002\b\u0011\u0005)1\u000f]1sW*\u0011\u0011BC\u0001\u0007CB\f7\r[3\u000b\u0003-\t1a\u001c:h\u0007\u0001\u0019B\u0001\u0001\b\u00135A\u0011q\u0002E\u0007\u0002\u0005%\u0011\u0011C\u0001\u0002\n\u000bZ\fG.^1u_J\u0004\"a\u0005\r\u000e\u0003QQ!!\u0006\f\u0002\rMD\u0017M]3e\u0015\t9B!A\u0003qCJ\fW.\u0003\u0002\u001a)\t\u0001\u0002*Y:Qe\u0016$\u0017n\u0019;j_:\u001cu\u000e\u001c\t\u0003'mI!\u0001\b\u000b\u0003\u0017!\u000b7\u000fT1cK2\u001cu\u000e\u001c\u0005\t=\u0001\u0011)\u0019!C!?\u0005\u0019Q/\u001b3\u0016\u0003\u0001\u0002\"!I\u0014\u000f\u0005\t*S\"A\u0012\u000b\u0003\u0011\nQa]2bY\u0006L!AJ\u0012\u0002\rA\u0013X\rZ3g\u0013\tA\u0013F\u0001\u0004TiJLgn\u001a\u0006\u0003M\rB\u0001b\u000b\u0001\u0003\u0002\u0003\u0006I\u0001I\u0001\u0005k&$\u0007\u0005C\u0003.\u0001\u0011\u0005a&\u0001\u0004=S:LGO\u0010\u000b\u0003_A\u0002\"a\u0004\u0001\t\u000bya\u0003\u0019\u0001\u0011\t\u000b5\u0002A\u0011\u0001\u001a\u0015\u0003=Bq\u0001\u000e\u0001C\u0002\u0013\u0005Q'\u0001\u0006nKR\u0014\u0018n\u0019(b[\u0016,\u0012A\u000e\t\u0004oa\u0002S\"\u0001\f\n\u0005e2\"!\u0002)be\u0006l\u0007BB\u001e\u0001A\u0003%a'A\u0006nKR\u0014\u0018n\u0019(b[\u0016\u0004\u0003\"B\u001f\u0001\t\u0003y\u0012!D4fi6+GO]5d\u001d\u0006lW\rC\u0003@\u0001\u0011\u0005\u0001)A\u0007tKRlU\r\u001e:jG:\u000bW.\u001a\u000b\u0003\u0003\nk\u0011\u0001\u0001\u0005\u0006\u0007z\u0002\r\u0001I\u0001\u0006m\u0006dW/\u001a\u0005\u0006\u000b\u0002!\tAR\u0001\u0011g\u0016$\bK]3eS\u000e$\u0018n\u001c8D_2$\"!Q$\t\u000b\r#\u0005\u0019\u0001\u0011\t\u000b%\u0003A\u0011\u0001&\u0002\u0017M,G\u000fT1cK2\u001cu\u000e\u001c\u000b\u0003\u0003.CQa\u0011%A\u0002\u0001BQ!\u0014\u0001\u0005B9\u000b\u0001\"\u001a<bYV\fG/\u001a\u000b\u0003\u001fJ\u0003\"A\t)\n\u0005E\u001b#A\u0002#pk\ndW\rC\u0003T\u0019\u0002\u0007A+A\u0004eCR\f7/\u001a;\u0011\u
0005UCV\"\u0001,\u000b\u0005]3\u0011aA:rY&\u0011\u0011L\u0016\u0002\n\t\u0006$\u0018M\u0012:b[\u0016DQa\u0017\u0001\u0005Bq\u000ba\"[:MCJ<WM\u001d\"fiR,'/F\u0001^!\t\u0011c,\u0003\u0002`G\t9!i\\8mK\u0006t\u0007\"B1\u0001\t\u0003\u0012\u0017\u0001B2paf$\"aL2\t\u000b\u0011\u0004\u0007\u0019A3\u0002\u000b\u0015DHO]1\u0011\u0005]2\u0017BA4\u0017\u0005!\u0001\u0016M]1n\u001b\u0006\u0004\bF\u0001\u0001j!\tQW.D\u0001l\u0015\tag!\u0001\u0006b]:|G/\u0019;j_:L!A\\6\u0003\u0019\u0015C\b/\u001a:j[\u0016tG/\u00197")
@Experimental
/* loaded from: input_file:org/apache/spark/ml/evaluation/RegressionEvaluator.class */
public final class RegressionEvaluator extends Evaluator implements HasPredictionCol, HasLabelCol {
    /** Unique identifier of this evaluator instance, fixed at construction. */
    private final String uid;

    /**
     * Name of the metric to compute. The constructor restricts it to {@code mse|rmse|r2|mae}
     * via a validator and registers {@code rmse} as its default.
     */
    private final Param<String> metricName;

    // Deliberately not final: these are populated by the HasLabelCol / HasPredictionCol trait
    // initializers ($init$), which call the synthetic *_setter_$*_$eq methods below rather than
    // assigning the fields directly in this constructor.
    private Param<String> labelCol;
    private Param<String> predictionCol;

    /** Returns the param holding the name of the label column. */
    @Override // org.apache.spark.ml.param.shared.HasLabelCol
    public final Param<String> labelCol() {
        return this.labelCol;
    }

    /** Trait-initialization hook: stores the label-column param created by HasLabelCol. */
    @Override // org.apache.spark.ml.param.shared.HasLabelCol
    public final void org$apache$spark$ml$param$shared$HasLabelCol$_setter_$labelCol_$eq(Param param) {
        this.labelCol = param;
    }

    /** Returns the current value of the label-column param (delegates to the trait implementation). */
    @Override // org.apache.spark.ml.param.shared.HasLabelCol
    public final String getLabelCol() {
        return HasLabelCol.Cclass.getLabelCol(this);
    }

    /** Returns the param holding the name of the prediction column. */
    @Override // org.apache.spark.ml.param.shared.HasPredictionCol
    public final Param<String> predictionCol() {
        return this.predictionCol;
    }

    /** Trait-initialization hook: stores the prediction-column param created by HasPredictionCol. */
    @Override // org.apache.spark.ml.param.shared.HasPredictionCol
    public final void org$apache$spark$ml$param$shared$HasPredictionCol$_setter_$predictionCol_$eq(Param param) {
        this.predictionCol = param;
    }

    /** Returns the current value of the prediction-column param (delegates to the trait implementation). */
    @Override // org.apache.spark.ml.param.shared.HasPredictionCol
    public final String getPredictionCol() {
        return HasPredictionCol.Cclass.getPredictionCol(this);
    }

    /** Returns the unique identifier supplied at construction. */
    @Override // org.apache.spark.ml.util.Identifiable
    public String uid() {
        return this.uid;
    }

    /** Returns the param selecting which regression metric {@link #evaluate} computes. */
    public Param<String> metricName() {
        return this.metricName;
    }

    /** Returns the currently configured metric name (or its default, {@code rmse}). */
    public String getMetricName() {
        return (String) $(metricName());
    }

    /**
     * Sets the metric to compute.
     *
     * @param str one of {@code mse}, {@code rmse}, {@code r2}, {@code mae}
     * @return this evaluator, for chaining
     */
    public RegressionEvaluator setMetricName(String str) {
        return (RegressionEvaluator) set(metricName(), str);
    }

    /**
     * Sets the name of the prediction column.
     *
     * @param str column name
     * @return this evaluator, for chaining
     */
    public RegressionEvaluator setPredictionCol(String str) {
        return (RegressionEvaluator) set(predictionCol(), str);
    }

    /**
     * Sets the name of the label column.
     *
     * @param str column name
     * @return this evaluator, for chaining
     */
    public RegressionEvaluator setLabelCol(String str) {
        return (RegressionEvaluator) set(labelCol(), str);
    }

    /**
     * Computes the configured regression metric over {@code dataFrame}.
     *
     * <p>Both the prediction column and the label column must have {@code DoubleType}. The two
     * columns are projected (prediction first, then label), wrapped in {@link RegressionMetrics},
     * and the metric selected by {@link #metricName()} is returned.
     *
     * @param dataFrame dataset containing the prediction and label columns
     * @return the value of the selected metric
     * @throws MatchError if the metric name is not one of {@code rmse}, {@code mse}, {@code r2},
     *     {@code mae}
     */
    @Override // org.apache.spark.ml.evaluation.Evaluator
    public double evaluate(DataFrame dataFrame) {
        StructType schema = dataFrame.schema();
        String predictionColName = (String) $(predictionCol());
        String labelColName = (String) $(labelCol());
        SchemaUtils$.MODULE$.checkColumnType(schema, predictionColName, DoubleType$.MODULE$, SchemaUtils$.MODULE$.checkColumnType$default$4());
        SchemaUtils$.MODULE$.checkColumnType(schema, labelColName, DoubleType$.MODULE$, SchemaUtils$.MODULE$.checkColumnType$default$4());

        // RegressionEvaluator$$anonfun$1 is the compiled Scala closure from the original source;
        // presumably it turns each Row into a (prediction, label) pair of doubles — TODO confirm.
        RDD<Tuple2<Object, Object>> predictionAndLabels = (RDD<Tuple2<Object, Object>>) dataFrame
                .select(predictionColName, Predef$.MODULE$.wrapRefArray(new String[]{labelColName}))
                .map(new RegressionEvaluator$$anonfun$1(this), ClassTag$.MODULE$.apply(Tuple2.class));
        RegressionMetrics metrics = new RegressionMetrics(predictionAndLabels);

        return selectMetric(metrics, (String) $(metricName()));
    }

    /**
     * Reads the metric called {@code name} from {@code metrics}.
     *
     * <p>Uses null-safe {@code literal.equals(name)} comparisons, so a {@code null} or unrecognised
     * name falls through to the {@link MatchError}.
     */
    private static double selectMetric(RegressionMetrics metrics, String name) {
        if ("rmse".equals(name)) {
            return metrics.rootMeanSquaredError();
        }
        if ("mse".equals(name)) {
            return metrics.meanSquaredError();
        }
        if ("r2".equals(name)) {
            return metrics.r2();
        }
        if ("mae".equals(name)) {
            return metrics.meanAbsoluteError();
        }
        throw new MatchError(name);
    }

    /**
     * Reports whether a larger value of the configured metric means a better model.
     *
     * @return {@code true} only for {@code r2}; {@code false} for the error metrics
     *     {@code rmse}, {@code mse} and {@code mae}
     * @throws MatchError if the metric name is not one of the four supported names
     */
    @Override // org.apache.spark.ml.evaluation.Evaluator
    public boolean isLargerBetter() {
        String name = (String) $(metricName());
        if ("r2".equals(name)) {
            return true;
        }
        if ("rmse".equals(name) || "mse".equals(name) || "mae".equals(name)) {
            return false;
        }
        throw new MatchError(name);
    }

    /** Returns a copy of this evaluator with {@code paramMap} applied on top of its params. */
    @Override // org.apache.spark.ml.evaluation.Evaluator, org.apache.spark.ml.param.Params
    public RegressionEvaluator copy(ParamMap paramMap) {
        return (RegressionEvaluator) defaultCopy(paramMap);
    }

    /**
     * Creates an evaluator with the given uid.
     *
     * <p>Runs the shared-param trait initializers, which populate {@link #labelCol()} and
     * {@link #predictionCol()}. It then creates the {@code metricName} param, restricted to
     * {@code mse|rmse|r2|mae}, and makes {@code rmse} its default.
     *
     * @param str unique identifier for this instance
     */
    public RegressionEvaluator(String str) {
        this.uid = str;
        HasPredictionCol.Cclass.$init$(this);
        HasLabelCol.Cclass.$init$(this);
        this.metricName = new Param<>(this, "metricName", "metric name in evaluation (mse|rmse|r2|mae)", ParamValidators$.MODULE$.inArray(new String[]{"mse", "rmse", "r2", "mae"}));
        setDefault(Predef$.MODULE$.wrapRefArray(new ParamPair[]{metricName().$minus$greater("rmse")}));
    }

    /** Creates an evaluator with a random uid prefixed by {@code regEval}. */
    public RegressionEvaluator() {
        this(Identifiable$.MODULE$.randomUID("regEval"));
    }
}
