package org.apache.spark.ml;

import org.apache.spark.annotation.DeveloperApi;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.Predictor;
import org.apache.spark.ml.PredictorParams;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.shared.HasFeaturesCol;
import org.apache.spark.ml.param.shared.HasLabelCol;
import org.apache.spark.ml.param.shared.HasPredictionCol;
import org.apache.spark.ml.param.shared.HasWeightCol;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.StructType;
import scala.Predef$;
import scala.collection.immutable.StringOps;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;

/* compiled from: Predictor.scala */
@DeveloperApi
@ScalaSignature(bytes = "\u0006\u0001\u0005-b!B\u0001\u0003\u0003\u0003Y!!\u0003)sK\u0012L7\r^8s\u0015\t\u0019A!\u0001\u0002nY*\u0011QAB\u0001\u0006gB\f'o\u001b\u0006\u0003\u000f!\ta!\u00199bG\",'\"A\u0005\u0002\u0007=\u0014xm\u0001\u0001\u0016\t1\u0001cfE\n\u0004\u000151\u0003c\u0001\b\u0010#5\t!!\u0003\u0002\u0011\u0005\tIQi\u001d;j[\u0006$xN\u001d\t\u0003%Ma\u0001\u0001B\u0003\u0015\u0001\t\u0007QCA\u0001N#\t1B\u0004\u0005\u0002\u001855\t\u0001DC\u0001\u001a\u0003\u0015\u00198-\u00197b\u0013\tY\u0002DA\u0004O_RD\u0017N\\4\u0011\t9ir$E\u0005\u0003=\t\u0011q\u0002\u0015:fI&\u001cG/[8o\u001b>$W\r\u001c\t\u0003%\u0001\"Q!\t\u0001C\u0002\t\u0012ABR3biV\u0014Xm\u001d+za\u0016\f\"AF\u0012\u0011\u0005]!\u0013BA\u0013\u0019\u0005\r\te.\u001f\t\u0003\u001d\u001dJ!\u0001\u000b\u0002\u0003\u001fA\u0013X\rZ5di>\u0014\b+\u0019:b[NDQA\u000b\u0001\u0005\u0002-\na\u0001P5oSRtD#\u0001\u0017\u0011\u000b9\u0001q$L\t\u0011\u0005IqC!B\u0018\u0001\u0005\u0004\u0001$a\u0002'fCJtWM]\t\u0003-1BQA\r\u0001\u0005\u0002M\n1b]3u\u0019\u0006\u0014W\r\\\"pYR\u0011Q\u0006\u000e\u0005\u0006kE\u0002\rAN\u0001\u0006m\u0006dW/\u001a\t\u0003oir!a\u0006\u001d\n\u0005eB\u0012A\u0002)sK\u0012,g-\u0003\u0002<y\t11\u000b\u001e:j]\u001eT!!\u000f\r\t\u000by\u0002A\u0011A 
\u0002\u001dM,GOR3biV\u0014Xm]\"pYR\u0011Q\u0006\u0011\u0005\u0006ku\u0002\rA\u000e\u0005\u0006\u0005\u0002!\taQ\u0001\u0011g\u0016$\bK]3eS\u000e$\u0018n\u001c8D_2$\"!\f#\t\u000bU\n\u0005\u0019\u0001\u001c\t\u000b\u0019\u0003A\u0011I$\u0002\u0007\u0019LG\u000f\u0006\u0002\u0012\u0011\")\u0011*\u0012a\u0001\u0015\u00069A-\u0019;bg\u0016$\bGA&S!\rau*U\u0007\u0002\u001b*\u0011a\nB\u0001\u0004gFd\u0017B\u0001)N\u0005\u001d!\u0015\r^1tKR\u0004\"A\u0005*\u0005\u0013MC\u0015\u0011!A\u0001\u0006\u0003\u0011#aA0%c!)Q\u000b\u0001D!-\u0006!1m\u001c9z)\tis\u000bC\u0003Y)\u0002\u0007\u0011,A\u0003fqR\u0014\u0018\r\u0005\u0002[;6\t1L\u0003\u0002]\u0005\u0005)\u0001/\u0019:b[&\u0011al\u0017\u0002\t!\u0006\u0014\u0018-\\'ba\")\u0001\r\u0001D\tC\u0006)AO]1j]R\u0011\u0011C\u0019\u0005\u0006\u0013~\u0003\ra\u0019\u0019\u0003I\u001a\u00042\u0001T(f!\t\u0011b\rB\u0005hE\u0006\u0005\t\u0011!B\u0001E\t\u0019q\f\n\u001a\t\r%\u0004A\u0011\u0001\u0002k\u0003A1W-\u0019;ve\u0016\u001cH)\u0019;b)f\u0004X-F\u0001l!\taw.D\u0001n\u0015\tqW*A\u0003usB,7/\u0003\u0002q[\nAA)\u0019;b)f\u0004X\rC\u0003s\u0001\u0011\u00053/A\bue\u0006t7OZ8s[N\u001b\u0007.Z7b)\t!x\u000f\u0005\u0002mk&\u0011a/\u001c\u0002\u000b'R\u0014Xo\u0019;UsB,\u0007\"\u0002=r\u0001\u0004!\u0018AB:dQ\u0016l\u0017\rC\u0003{\u0001\u0011E10\u0001\u000bfqR\u0014\u0018m\u0019;MC\n,G.\u001a3Q_&tGo\u001d\u000b\u0004y\u0006E\u0001#B?\u0002\u0002\u0005\u0015Q\"\u0001@\u000b\u0005}$\u0011a\u0001:eI&\u0019\u00111\u0001@\u0003\u0007I#E\t\u0005\u0003\u0002\b\u00055QBAA\u0005\u0015\r\tYAA\u0001\bM\u0016\fG/\u001e:f\u0013\u0011\ty!!\u0003\u0003\u00191\u000b'-\u001a7fIB{\u0017N\u001c;\t\r%K\b\u0019AA\na\u0011\t)\"!\u0007\u0011\t1{\u0015q\u0003\t\u0004%\u0005eAaCA\u000e\u0003#\t\t\u0011!A\u0003\u0002\t\u00121a\u0018\u00134Q\r\u0001\u0011q\u0004\t\u0005\u0003C\t9#\u0004\u0002\u0002$)\u0019\u0011Q\u0005\u0003\u0002\u0015\u0005tgn\u001c;bi&|g.\u0003\u0003\u0002*\u0005\r\"\u0001\u0004#fm\u0016dw\u000e]3s\u0003BL\u0007")
/* loaded from: input_file:org/apache/spark/ml/Predictor.class */
/**
 * Abstract estimator for prediction tasks, decompiled from {@code Predictor.scala}.
 *
 * <p>The class stores the label, features and prediction column parameters mixed in
 * through {@link PredictorParams}, exposes fluent setters for them, and implements
 * {@link #fit(Dataset)} on top of the abstract {@link #train(Dataset)} hook.
 *
 * @param <FeaturesType> type of the features consumed by the produced model
 * @param <Learner> concrete predictor type, returned by the fluent setters and by {@link #copy}
 * @param <M> type of the model produced by {@link #fit(Dataset)}
 */
public abstract class Predictor<FeaturesType, Learner extends Predictor<FeaturesType, Learner, M>, M extends PredictionModel<FeaturesType, M>> extends Estimator<M> implements PredictorParams {
    // These fields are deliberately not final: they are assigned through the synthetic
    // "_setter_" methods below rather than directly in the constructor, which Java
    // source does not allow for final fields.
    private Param<String> predictionCol;
    private Param<String> featuresCol;
    private Param<String> labelCol;

    /**
     * Runs the initialisers of the mixed-in parameter traits, in the order
     * label column, features column, prediction column, predictor params.
     */
    public Predictor() {
        HasLabelCol.Cclass.$init$(this);
        HasFeaturesCol.Cclass.$init$(this);
        HasPredictionCol.Cclass.$init$(this);
        PredictorParams.Cclass.$init$(this);
    }

    /**
     * Delegates schema validation to the {@link PredictorParams} trait implementation.
     *
     * @param structType schema to validate
     * @param z boolean flag forwarded unchanged to the trait implementation
     *          (NOTE(review): presumably the "fitting" flag; confirm against PredictorParams)
     * @param dataType data type forwarded unchanged to the trait implementation
     * @return the schema returned by the trait implementation
     */
    @Override // org.apache.spark.ml.PredictorParams
    public StructType validateAndTransformSchema(StructType structType, boolean z, DataType dataType) {
        return PredictorParams.Cclass.validateAndTransformSchema(this, structType, z, dataType);
    }

    /** Returns the parameter object stored for the prediction column. */
    @Override // org.apache.spark.ml.param.shared.HasPredictionCol
    public final Param<String> predictionCol() {
        return this.predictionCol;
    }

    /**
     * Synthetic trait setter: stores the prediction column parameter.
     * The parameter type is kept raw to match the decompiled signature.
     */
    @Override // org.apache.spark.ml.param.shared.HasPredictionCol
    @SuppressWarnings({"rawtypes", "unchecked"})
    public final void org$apache$spark$ml$param$shared$HasPredictionCol$_setter_$predictionCol_$eq(Param param) {
        this.predictionCol = param;
    }

    /** Returns the prediction column name as resolved by the {@link HasPredictionCol} trait. */
    @Override // org.apache.spark.ml.param.shared.HasPredictionCol
    public final String getPredictionCol() {
        return HasPredictionCol.Cclass.getPredictionCol(this);
    }

    /** Returns the parameter object stored for the features column. */
    @Override // org.apache.spark.ml.param.shared.HasFeaturesCol
    public final Param<String> featuresCol() {
        return this.featuresCol;
    }

    /**
     * Synthetic trait setter: stores the features column parameter.
     * The parameter type is kept raw to match the decompiled signature.
     */
    @Override // org.apache.spark.ml.param.shared.HasFeaturesCol
    @SuppressWarnings({"rawtypes", "unchecked"})
    public final void org$apache$spark$ml$param$shared$HasFeaturesCol$_setter_$featuresCol_$eq(Param param) {
        this.featuresCol = param;
    }

    /** Returns the features column name as resolved by the {@link HasFeaturesCol} trait. */
    @Override // org.apache.spark.ml.param.shared.HasFeaturesCol
    public final String getFeaturesCol() {
        return HasFeaturesCol.Cclass.getFeaturesCol(this);
    }

    /** Returns the parameter object stored for the label column. */
    @Override // org.apache.spark.ml.param.shared.HasLabelCol
    public final Param<String> labelCol() {
        return this.labelCol;
    }

    /**
     * Synthetic trait setter: stores the label column parameter.
     * The parameter type is kept raw to match the decompiled signature.
     */
    @Override // org.apache.spark.ml.param.shared.HasLabelCol
    @SuppressWarnings({"rawtypes", "unchecked"})
    public final void org$apache$spark$ml$param$shared$HasLabelCol$_setter_$labelCol_$eq(Param param) {
        this.labelCol = param;
    }

    /** Returns the label column name as resolved by the {@link HasLabelCol} trait. */
    @Override // org.apache.spark.ml.param.shared.HasLabelCol
    public final String getLabelCol() {
        return HasLabelCol.Cclass.getLabelCol(this);
    }

    /**
     * Sets the label column parameter.
     *
     * @param str name of the label column
     * @return the result of {@code set(...)} viewed as {@code Learner}
     *         (presumably this instance, for call chaining)
     */
    @SuppressWarnings("unchecked")
    public Learner setLabelCol(String str) {
        // The value is passed as a plain String; a String can never be cast to Param.
        return (Learner) set(labelCol(), str);
    }

    /**
     * Sets the features column parameter.
     *
     * @param str name of the features column
     * @return the result of {@code set(...)} viewed as {@code Learner}
     *         (presumably this instance, for call chaining)
     */
    @SuppressWarnings("unchecked")
    public Learner setFeaturesCol(String str) {
        return (Learner) set(featuresCol(), str);
    }

    /**
     * Sets the prediction column parameter.
     *
     * @param str name of the prediction column
     * @return the result of {@code set(...)} viewed as {@code Learner}
     *         (presumably this instance, for call chaining)
     */
    @SuppressWarnings("unchecked")
    public Learner setPredictionCol(String str) {
        return (Learner) set(predictionCol(), str);
    }

    /**
     * Fits a model to the given dataset.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>run {@code transformSchema} on the dataset schema;</li>
     *   <li>cast the label column to double, keeping its original metadata;</li>
     *   <li>if this predictor also implements {@link HasWeightCol} and its weight column
     *       parameter is defined and non-empty, cast that column to double as well;</li>
     *   <li>call {@link #train(Dataset)}, set this predictor as the model's parent and
     *       copy parameter values onto the model.</li>
     * </ol>
     *
     * <p>No explicit {@code Model fit(Dataset)} bridge is declared: the compiler emits it
     * for this covariant override, and declaring it by hand clashes with this method
     * because both have the same erasure.
     *
     * @param dataset input data containing at least the label column
     * @return the trained model with this predictor set as its parent
     */
    @Override // org.apache.spark.ml.Estimator
    @SuppressWarnings("unchecked")
    public M fit(Dataset<?> dataset) {
        StructType inputSchema = dataset.schema();
        transformSchema(inputSchema, true);

        Dataset<?> prepared = castToDouble(dataset, inputSchema, (String) $(labelCol()));
        if (this instanceof HasWeightCol) {
            Param<String> weightParam = ((HasWeightCol) this).weightCol();
            if (isDefined(weightParam) && !((String) $(weightParam)).isEmpty()) {
                // Metadata is still read from the original input schema, not from
                // the already label-cast dataset.
                prepared = castToDouble(prepared, inputSchema, (String) $(weightParam));
            }
        }
        return (M) copyValues(train(prepared).setParent(this), copyValues$default$2());
    }

    /**
     * Replaces {@code columnName} in {@code data} with the same column cast to double,
     * attaching the metadata that {@code metadataSource} holds for that column.
     */
    private static Dataset<?> castToDouble(Dataset<?> data, StructType metadataSource, String columnName) {
        Column asDouble = functions$.MODULE$.col(columnName).cast(DoubleType$.MODULE$);
        return data.withColumn(columnName, asDouble, metadataSource.apply(columnName).metadata());
    }

    /**
     * Creates a copy of this predictor.
     *
     * @param paramMap parameter map handed to the implementation
     * @return the copy, typed as the concrete learner
     */
    @Override // org.apache.spark.ml.Estimator, org.apache.spark.ml.PipelineStage, org.apache.spark.ml.param.Params
    public abstract Learner copy(ParamMap paramMap);

    /**
     * Trains a model on a dataset. {@link #fit(Dataset)} calls this after casting the
     * label column (and, when applicable, the weight column) to double.
     *
     * @param dataset training data
     * @return the fitted model
     */
    public abstract M train(Dataset<?> dataset);

    /**
     * Returns the data type passed to schema validation for the features column.
     * This implementation returns a new {@link VectorUDT}.
     */
    public DataType featuresDataType() {
        return new VectorUDT();
    }

    /**
     * Validates and transforms the schema by calling
     * {@link #validateAndTransformSchema(StructType, boolean, DataType)} with the flag
     * set to {@code true} and {@link #featuresDataType()} as the data type.
     *
     * @param structType input schema
     * @return the schema returned by the validation call
     */
    @Override // org.apache.spark.ml.PipelineStage
    public StructType transformSchema(StructType structType) {
        return validateAndTransformSchema(structType, true, featuresDataType());
    }

    /**
     * Selects the label and features columns (in that order) from the dataset and maps
     * the resulting rows to {@link LabeledPoint} instances.
     *
     * <p>The per-row conversion lives in the compiled closure class
     * {@code Predictor$$anonfun$extractLabeledPoints$1}, which is not visible in this file.
     *
     * @param dataset data containing the label and features columns
     * @return an RDD of labeled points
     */
    public RDD<LabeledPoint> extractLabeledPoints(Dataset<?> dataset) {
        Column label = functions$.MODULE$.col((String) $(labelCol()));
        Column features = functions$.MODULE$.col((String) $(featuresCol()));
        return dataset.select(Predef$.MODULE$.wrapRefArray(new Column[]{label, features}))
                .rdd()
                .map(new Predictor$$anonfun$extractLabeledPoints$1(this), ClassTag$.MODULE$.apply(LabeledPoint.class));
    }
}
