package org.apache.spark.ml.feature;

import java.io.IOException;
import java.io.Serializable;
import java.util.Objects;
import org.apache.spark.SparkException;
import org.apache.spark.annotation.Experimental;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.attribute.AttributeGroup;
import org.apache.spark.ml.attribute.AttributeGroup$;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.ParamValidators$;
import org.apache.spark.ml.param.shared.HasHandleInvalid;
import org.apache.spark.ml.param.shared.HasInputCol;
import org.apache.spark.ml.util.DefaultParamsWritable;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.ml.util.MLReader;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.ml.util.MLWriter;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.reflect.ScalaSignature;
import scala.reflect.api.Mirror;
import scala.reflect.api.TypeCreator;
import scala.reflect.api.Types;
import scala.reflect.api.Universe;
import scala.runtime.BoxesRunTime;

/* compiled from: VectorSizeHint.scala */
@ScalaSignature(bytes = "\u0006\u0001\t\ra\u0001\u0002\u0010 \u0001)B\u0001\u0002\u0011\u0001\u0003\u0006\u0004%\t%\u0011\u0005\t1\u0002\u0011\t\u0011)A\u0005\u0005\")!\f\u0001C\u00017\")!\f\u0001C\u0001C\"91\r\u0001b\u0001\n\u0003!\u0007B\u00026\u0001A\u0003%Q\rC\u0003m\u0001\u0011\u0005Q\u000eC\u0003t\u0001\u0011\u0005A\u000fC\u0003z\u0001\u0011\u0005!\u0010C\u0004~\u0001\t\u0007I\u0011\t@\t\u000f\u0005\u001d\u0001\u0001)A\u0005\u007f\"9\u00111\u0002\u0001\u0005\u0002\u00055\u0001bBA\n\u0001\u0011\u0005\u0013Q\u0003\u0005\b\u0003?\u0002A\u0011BA1\u0011\u001d\t\u0019\t\u0001C!\u0003\u000bCq!a#\u0001\t\u0003\niiB\u0004\u0002&~A\t!a*\u0007\ryy\u0002\u0012AAU\u0011\u0019Q&\u0003\"\u0001\u0002>\"Q\u0011q\u0018\nC\u0002\u0013\u0005q$!1\t\u0011\u0005E'\u0003)A\u0005\u0003\u0007D!\"a5\u0013\u0005\u0004%\taHAa\u0011!\t)N\u0005Q\u0001\n\u0005\r\u0007BCAl%\t\u0007I\u0011A\u0010\u0002B\"A\u0011\u0011\u001c\n!\u0002\u0013\t\u0019\r\u0003\u0006\u0002\\J\u0011\r\u0011\"\u0001 \u0003;D\u0001\"!:\u0013A\u0003%\u0011q\u001c\u0005\b\u0003O\u0014B\u0011IAu\u0011%\t\tPEA\u0001\n\u0013\t\u0019P\u0001\bWK\u000e$xN]*ju\u0016D\u0015N\u001c;\u000b\u0005\u0001\n\u0013a\u00024fCR,(/\u001a\u0006\u0003E\r\n!!\u001c7\u000b\u0005\u0011*\u0013!B:qCJ\\'B\u0001\u0014(\u0003\u0019\t\u0007/Y2iK*\t\u0001&A\u0002pe\u001e\u001c\u0001aE\u0003\u0001W=:$\b\u0005\u0002-[5\t\u0011%\u0003\u0002/C\tYAK]1og\u001a|'/\\3s!\t\u0001T'D\u00012\u0015\t\u00114'\u0001\u0004tQ\u0006\u0014X\r\u001a\u0006\u0003i\u0005\nQ\u0001]1sC6L!AN\u0019\u0003\u0017!\u000b7/\u00138qkR\u001cu\u000e\u001c\t\u0003aaJ!!O\u0019\u0003!!\u000b7\u000fS1oI2,\u0017J\u001c<bY&$\u0007CA\u001e?\u001b\u0005a$BA\u001f\"\u0003\u0011)H/\u001b7\n\u0005}b$!\u0006#fM\u0006,H\u000e\u001e)be\u0006l7o\u0016:ji\u0006\u0014G.Z\u0001\u0004k&$W#\u0001\"\u0011\u0005\rceB\u0001#K!\t)\u0005*D\u0001G\u0015\t9\u0015&\u0001\u0004=e>|GO\u0010\u0006\u0002\u0013\u0006)1oY1mC&\u00111\nS\u0001\u0007!J,G-\u001a4\n\u00055s%AB*ue&twM\u0003\u0002L\u0011\"\u001a\u0011\u0001\u0
015,\u0011\u0005E#V\"\u0001*\u000b\u0005M\u001b\u0013AC1o]>$\u0018\r^5p]&\u0011QK\u0015\u0002\u0006'&t7-Z\u0011\u0002/\u0006)!GL\u001a/a\u0005!Q/\u001b3!Q\r\u0011\u0001KV\u0001\u0007y%t\u0017\u000e\u001e \u0015\u0005qs\u0006CA/\u0001\u001b\u0005y\u0002\"\u0002!\u0004\u0001\u0004\u0011\u0005f\u00010Q-\"\u001a1\u0001\u0015,\u0015\u0003qC3\u0001\u0002)W\u0003\u0011\u0019\u0018N_3\u0016\u0003\u0015\u0004\"AZ4\u000e\u0003MJ!\u0001[\u001a\u0003\u0011%sG\u000fU1sC6D3!\u0002)W\u0003\u0015\u0019\u0018N_3!Q\r1\u0001KV\u0001\bO\u0016$8+\u001b>f+\u0005q\u0007CA8q\u001b\u0005A\u0015BA9I\u0005\rIe\u000e\u001e\u0015\u0004\u000fA3\u0016aB:fiNK'0\u001a\u000b\u0003kZl\u0011\u0001\u0001\u0005\u0006o\"\u0001\rA\\\u0001\u0006m\u0006dW/\u001a\u0015\u0004\u0011A3\u0016aC:fi&s\u0007/\u001e;D_2$\"!^>\t\u000b]L\u0001\u0019\u0001\")\u0007%\u0001f+A\u0007iC:$G.Z%om\u0006d\u0017\u000eZ\u000b\u0002\u007fB!a-!\u0001C\u0013\r\t\u0019a\r\u0002\u0006!\u0006\u0014\u0018-\u001c\u0015\u0004\u0015A3\u0016A\u00045b]\u0012dW-\u00138wC2LG\r\t\u0015\u0004\u0017A3\u0016\u0001E:fi\"\u000bg\u000e\u001a7f\u0013:4\u0018\r\\5e)\r)\u0018q\u0002\u0005\u0006o2\u0001\rA\u0011\u0015\u0004\u0019A3\u0016!\u0003;sC:\u001chm\u001c:n)\u0011\t9\"!\u000f\u0011\t\u0005e\u00111\u0007\b\u0005\u00037\tiC\u0004\u0003\u0002\u001e\u0005%b\u0002BA\u0010\u0003OqA!!\t\u0002&9\u0019Q)a\t\n\u0003!J!AJ\u0014\n\u0005\u0011*\u0013bAA\u0016G\u0005\u00191/\u001d7\n\t\u0005=\u0012\u0011G\u0001\ba\u0006\u001c7.Y4f\u0015\r\tYcI\u0005\u0005\u0003k\t9DA\u0005ECR\fgI]1nK*!\u0011qFA\u0019\u0011\u001d\tY$\u0004a\u0001\u0003{\tq\u0001Z1uCN,G\u000f\r\u0003\u0002@\u0005-\u0003CBA!\u0003\u0007\n9%\u0004\u0002\u00022%!\u0011QIA\u0019\u0005\u001d!\u0015\r^1tKR\u0004B!!\u0013\u0002L1\u0001A\u0001DA'\u0003s\t\t\u0011!A\u0003\u0002\u0005=#aA0%cE!\u0011\u0011KA,!\ry\u00171K\u0005\u0004\u0003+B%a\u0002(pi\"Lgn\u001a\t\u0004_\u0006e\u0013bAA.\u0011\n\u0019\u0011I\\=)\u00075\u0001f+A\u000bwC2LG-\u0019;f'\u000eDW-\\1B]\u0012\u001c\u0016N_3\u0015\r\u0005\r\u0014qN
A@!\u0011\t)'a\u001b\u000e\u0005\u0005\u001d$bAA5C\u0005I\u0011\r\u001e;sS\n,H/Z\u0005\u0005\u0003[\n9G\u0001\bBiR\u0014\u0018NY;uK\u001e\u0013x.\u001e9\t\u000f\u0005Ed\u00021\u0001\u0002t\u000511o\u00195f[\u0006\u0004B!!\u001e\u0002|5\u0011\u0011q\u000f\u0006\u0005\u0003s\n\t$A\u0003usB,7/\u0003\u0003\u0002~\u0005]$AC*ueV\u001cG\u000fV=qK\"9\u0011\u0011\u0011\bA\u0002\u0005\r\u0014!B4s_V\u0004\u0018a\u0004;sC:\u001chm\u001c:n'\u000eDW-\\1\u0015\t\u0005M\u0014q\u0011\u0005\b\u0003cz\u0001\u0019AA:Q\ry\u0001KV\u0001\u0005G>\u0004\u0018\u0010F\u0002v\u0003\u001fCq!!%\u0011\u0001\u0004\t\u0019*A\u0003fqR\u0014\u0018\rE\u0002g\u0003+K1!a&4\u0005!\u0001\u0016M]1n\u001b\u0006\u0004\bf\u0001\tQ-\"\u001a\u0001\u0001\u0015,)\u0007\u0001\ty\nE\u0002R\u0003CK1!a)S\u00051)\u0005\u0010]3sS6,g\u000e^1m\u000391Vm\u0019;peNK'0\u001a%j]R\u0004\"!\u0018\n\u0014\u000fI\tY+!-\u00028B\u0019q.!,\n\u0007\u0005=\u0006J\u0001\u0004B]f\u0014VM\u001a\t\u0005w\u0005MF,C\u0002\u00026r\u0012Q\u0003R3gCVdG\u000fU1sC6\u001c(+Z1eC\ndW\rE\u0002p\u0003sK1!a/I\u00051\u0019VM]5bY&T\u0018M\u00197f)\t\t9+\u0001\nP!RKU*S*U\u0013\u000e{\u0016J\u0014,B\u0019&#UCAAb!\u0011\t)-a4\u000e\u0005\u0005\u001d'\u0002BAe\u0003\u0017\fA\u0001\\1oO*\u0011\u0011QZ\u0001\u0005U\u00064\u0018-C\u0002N\u0003\u000f\f1c\u0014)U\u00136K5\u000bV%D?&se+\u0011'J\t\u0002\nQ\"\u0012*S\u001fJ{\u0016J\u0014,B\u0019&#\u0015AD#S%>\u0013v,\u0013(W\u00032KE\tI\u0001\r'.K\u0005kX%O-\u0006c\u0015\nR\u0001\u000e'.K\u0005kX%O-\u0006c\u0015\n\u0012\u0011\u0002/M,\b\u000f]8si\u0016$\u0007*\u00198eY\u0016LeN^1mS\u0012\u001cXCAAp!\u0011y\u0017\u0011\u001d\"\n\u0007\u0005\r\bJA\u0003BeJ\f\u00170\u0001\rtkB\u0004xN\u001d;fI\"\u000bg\u000e\u001a7f\u0013:4\u0018\r\\5eg\u0002\nA\u0001\\8bIR\u0019A,a;\t\r\u00055H\u00041\u0001C\u0003\u0011\u0001\u0018\r\u001e5)\u0007q\u0001f+A\u0006sK\u0006$'+Z:pYZ,GCAA{!\u0011\t)-a>\n\t\u0005e\u0018q\u0019\u0002\u0007\u001f\nTWm\u0019;)\u0007I\u0001f\u000bK\u0002\u0013\u0003?C3!\u0005)WQ\r\t\u0012q\u0014")
@Experimental
/* loaded from: input_file:org/apache/spark/ml/feature/VectorSizeHint.class */
/*
 * Java rendering of the Scala transformer org.apache.spark.ml.feature.VectorSizeHint.
 *
 * The "size" param declares the expected length of the vectors in "inputCol". transform() records
 * that length in the column's AttributeGroup metadata and applies the handleInvalid policy
 * (options listed in the param doc: `skip`, `error`, `optimistic`; `error` is the default):
 *   - ERROR_INVALID:      null or wrongly sized vectors raise a SparkException inside a UDF;
 *   - SKIP_INVALID:       null or wrongly sized vectors become null and those rows are dropped;
 *   - OPTIMISTIC_INVALID: values are passed through unchecked; only the metadata is updated.
 */
public class VectorSizeHint extends Transformer implements HasInputCol, HasHandleInvalid, DefaultParamsWritable {
    /** Identifier of this stage, fixed at construction. */
    private final String uid;
    /** Expected vector length; the validator installed in the constructor accepts only values >= 0. */
    private final IntParam size;
    /** This class's own handleInvalid param; it replaces the one the HasHandleInvalid trait would install. */
    private final Param<String> handleInvalid;
    /**
     * Name of the vector column to annotate. Deliberately not {@code final}: it is assigned through
     * the HasInputCol trait setter invoked from the constructor, not by a direct field write.
     */
    private Param<String> inputCol;

    public static VectorSizeHint load(String str) {
        return VectorSizeHint$.MODULE$.load(str);
    }

    public static MLReader<VectorSizeHint> read() {
        return VectorSizeHint$.MODULE$.read();
    }

    /**
     * Returns the writer provided by the DefaultParamsWritable trait.
     *
     * <p>Delegates explicitly to the trait's default method; an unqualified {@code write()} call
     * here would invoke this method again and recurse until StackOverflowError.
     */
    @Override // org.apache.spark.ml.util.DefaultParamsWritable, org.apache.spark.ml.util.MLWritable
    public MLWriter write() {
        return DefaultParamsWritable.super.write();
    }

    /**
     * Saves this stage to {@code str} using the implementation inherited from MLWritable.
     *
     * @throws IOException as declared by MLWritable.save
     */
    @Override // org.apache.spark.ml.util.MLWritable
    public void save(String str) throws IOException {
        DefaultParamsWritable.super.save(str);
    }

    /** Returns the handleInvalid value via the HasHandleInvalid trait's accessor (non-recursive). */
    @Override // org.apache.spark.ml.param.shared.HasHandleInvalid
    public final String getHandleInvalid() {
        return HasHandleInvalid.super.getHandleInvalid();
    }

    /** Returns the inputCol value via the HasInputCol trait's accessor (non-recursive). */
    @Override // org.apache.spark.ml.param.shared.HasInputCol
    public final String getInputCol() {
        return HasInputCol.super.getInputCol();
    }

    /**
     * Intentionally a no-op. This class assigns its own {@link #handleInvalid} in the constructor,
     * so the param the HasHandleInvalid trait tries to install is discarded.
     */
    @Override // org.apache.spark.ml.param.shared.HasHandleInvalid
    public void org$apache$spark$ml$param$shared$HasHandleInvalid$_setter_$handleInvalid_$eq(Param<String> param) {
    }

    @Override // org.apache.spark.ml.param.shared.HasInputCol
    public final Param<String> inputCol() {
        return this.inputCol;
    }

    /** Trait setter through which the constructor installs the inputCol param. */
    @Override // org.apache.spark.ml.param.shared.HasInputCol
    public final void org$apache$spark$ml$param$shared$HasInputCol$_setter_$inputCol_$eq(Param<String> param) {
        this.inputCol = param;
    }

    @Override // org.apache.spark.ml.util.Identifiable
    public String uid() {
        return this.uid;
    }

    /** The param holding the expected vector length. */
    public IntParam size() {
        return this.size;
    }

    /** Returns the size param's value (set or default), unboxed to {@code int}. */
    public int getSize() {
        return BoxesRunTime.unboxToInt(getOrDefault(size()));
    }

    /** Sets the expected vector length. The param's validator accepts only values >= 0. */
    public VectorSizeHint setSize(int i) {
        return (VectorSizeHint) set(size(), BoxesRunTime.boxToInteger(i));
    }

    /** Sets the name of the vector column to annotate. */
    public VectorSizeHint setInputCol(String str) {
        return (VectorSizeHint) set(inputCol(), str);
    }

    @Override // org.apache.spark.ml.param.shared.HasHandleInvalid
    public Param<String> handleInvalid() {
        return this.handleInvalid;
    }

    /** Sets the invalid-vector policy; accepted values come from supportedHandleInvalids(). */
    public VectorSizeHint setHandleInvalid(String str) {
        return (VectorSizeHint) set(handleInvalid(), str);
    }

    /**
     * Annotates {@code inputCol} with the configured size and applies the handleInvalid policy.
     *
     * <p>Under OPTIMISTIC_INVALID, if the column's existing metadata already records the configured
     * size, the input is returned unchanged as a DataFrame. Otherwise the column is replaced by a
     * (possibly checked) copy carrying the new metadata. Under SKIP_INVALID, rows whose column became
     * null are then dropped.
     *
     * @throws IllegalArgumentException if {@code inputCol} is not a vector column, or its metadata
     *     records a different size (see validateSchemaAndSize)
     * @throws MatchError if handleInvalid is none of the supported values
     */
    @Override // org.apache.spark.ml.Transformer
    public Dataset<Row> transform(Dataset<?> dataset) {
        String column = getInputCol();
        int expectedSize = getSize();
        String mode = getHandleInvalid();

        AttributeGroup existing = AttributeGroup$.MODULE$.fromStructField(dataset.schema().apply(column));
        // Validate before any early return so schema/size conflicts always surface.
        AttributeGroup annotated = validateSchemaAndSize(dataset.schema(), existing);

        boolean optimistic = Objects.equals(mode, VectorSizeHint$.MODULE$.OPTIMISTIC_INVALID());
        if (optimistic && existing.size() == expectedSize) {
            return dataset.toDF();
        }

        Column checked;
        if (optimistic) {
            checked = functions$.MODULE$.col(column);
        } else if (Objects.equals(mode, VectorSizeHint$.MODULE$.ERROR_INVALID())) {
            // Marked nondeterministic as in the original. Presumably this stops the optimizer
            // from eliding or reordering the check; TODO confirm.
            checked = vectorUdfColumn(failOnInvalid(expectedSize), true, column);
        } else if (Objects.equals(mode, VectorSizeHint$.MODULE$.SKIP_INVALID())) {
            checked = vectorUdfColumn(nullOnInvalid(expectedSize), false, column);
        } else {
            throw new MatchError(mode);
        }

        Dataset<Row> result = dataset.withColumn(column, checked.as(column, annotated.toMetadata()));
        if (Objects.equals(mode, VectorSizeHint$.MODULE$.SKIP_INVALID())) {
            return result.na().drop(new String[]{column});
        }
        return result;
    }

    /**
     * Builds the ERROR_INVALID check. It is the identity on vectors of length {@code expectedSize}
     * and throws SparkException for null or wrongly sized vectors. The lambda is Serializable
     * because Spark ships UDF closures to executors.
     */
    private static Function1<Vector, Vector> failOnInvalid(int expectedSize) {
        return (Function1<Vector, Vector> & Serializable) vector -> {
            if (vector == null) {
                throw sneakyThrow(new SparkException("Got null vector in VectorSizeHint, set `handleInvalid` "
                        + "to 'skip' to filter invalid rows."));
            }
            if (vector.size() != expectedSize) {
                throw sneakyThrow(new SparkException("VectorSizeHint Expecting a vector of size " + expectedSize
                        + " but" + " got " + vector.size()));
            }
            return vector;
        };
    }

    /**
     * Builds the SKIP_INVALID check. It is the identity on non-null vectors of length
     * {@code expectedSize} and returns null otherwise. It is Serializable for shipping to executors.
     */
    private static Function1<Vector, Vector> nullOnInvalid(int expectedSize) {
        return (Function1<Vector, Vector> & Serializable) vector ->
                vector != null && vector.size() == expectedSize ? vector : null;
    }

    /**
     * Wraps {@code check} in a Vector-to-Vector UDF and applies it to {@code col(inputCol)}.
     *
     * @param check per-row function; must be Serializable
     * @param nondeterministic whether to mark the UDF with asNondeterministic()
     * @param inputCol column the UDF is applied to
     */
    private static Column vectorUdfColumn(Function1<Vector, Vector> check, boolean nondeterministic, String inputCol) {
        // udf needs runtime TypeTags for its result and argument types; both are Vector here.
        TypeCreator vectorType = new VectorTypeCreator();
        UserDefinedFunction udf = functions$.MODULE$.udf(
                check,
                scala.reflect.runtime.package$.MODULE$.universe().TypeTag().apply(
                        scala.reflect.runtime.package$.MODULE$.universe().runtimeMirror(VectorSizeHint.class.getClassLoader()),
                        vectorType),
                scala.reflect.runtime.package$.MODULE$.universe().TypeTag().apply(
                        scala.reflect.runtime.package$.MODULE$.universe().runtimeMirror(VectorSizeHint.class.getClassLoader()),
                        vectorType));
        if (nondeterministic) {
            udf = udf.asNondeterministic();
        }
        return udf.apply(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col(inputCol)}));
    }

    /**
     * Verifies that {@code inputCol} is a vector column and reconciles its recorded size with the
     * size param.
     *
     * @param structType schema containing {@code inputCol}
     * @param attributeGroup attribute metadata currently attached to {@code inputCol}
     * @return {@code attributeGroup} itself when it already records the configured size. Otherwise,
     *     when its size is -1 (treated here as "no size recorded"), a new AttributeGroup named
     *     after {@code inputCol} with the configured size.
     * @throws IllegalArgumentException if the column type is not VectorUDT (raised by Predef.require),
     *     or the metadata records a size different from the configured one
     */
    private AttributeGroup validateSchemaAndSize(StructType structType, AttributeGroup attributeGroup) {
        int expectedSize = getSize();
        String column = getInputCol();
        DataType dataType = structType.apply(column).dataType();
        Predef$.MODULE$.require(dataType instanceof VectorUDT,
                () -> "Input column, " + column + " must be of Vector type, got " + dataType);

        int recordedSize = attributeGroup.size();
        if (recordedSize == expectedSize) {
            return attributeGroup;
        }
        if (recordedSize != -1) {
            throw new IllegalArgumentException("Trying to set size of vectors in `" + column + "` to "
                    + expectedSize + " but size " + "already set to " + recordedSize + ".");
        }
        return new AttributeGroup(column, expectedSize);
    }

    /**
     * Returns a copy of {@code structType} in which {@code inputCol}'s metadata is replaced by the
     * AttributeGroup that validateSchemaAndSize produces. All other fields are unchanged.
     */
    @Override // org.apache.spark.ml.PipelineStage
    public StructType transformSchema(StructType structType) {
        int index = structType.fieldIndex(getInputCol());
        StructField[] fields = structType.fields().clone();
        StructField field = fields[index];
        fields[index] = field.copy(field.copy$default$1(), field.copy$default$2(), field.copy$default$3(),
                validateSchemaAndSize(structType, AttributeGroup$.MODULE$.fromStructField(field)).toMetadata());
        return new StructType(fields);
    }

    @Override // org.apache.spark.ml.Transformer, org.apache.spark.ml.PipelineStage, org.apache.spark.ml.param.Params
    public VectorSizeHint copy(ParamMap paramMap) {
        return (VectorSizeHint) defaultCopy(paramMap);
    }

    /**
     * Creates a VectorSizeHint with the given uid.
     *
     * <p>The statement order mirrors the Scala initialization order (trait members first, then this
     * class's params and defaults) and should be preserved.
     */
    public VectorSizeHint(String str) {
        this.uid = str;
        org$apache$spark$ml$param$shared$HasInputCol$_setter_$inputCol_$eq(new Param<>(this, "inputCol", "input column name"));
        // Discarded by the no-op setter; kept so trait initialization matches the original.
        org$apache$spark$ml$param$shared$HasHandleInvalid$_setter_$handleInvalid_$eq(new Param<>(this, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later", ParamValidators$.MODULE$.inArray(new String[]{"skip", "error"})));
        MLWritable.$init$(this);
        DefaultParamsWritable.$init$((DefaultParamsWritable) this);
        // Serializable like the params that own it; accepts only non-negative sizes.
        this.size = new IntParam(this, "size", "Size of vectors in column.",
                (Function1<Object, Object> & Serializable) value -> BoxesRunTime.unboxToInt(value) >= 0);
        this.handleInvalid = new Param<>(this, "handleInvalid", "How to handle invalid vectors in inputCol. Invalid vectors include nulls and vectors with the wrong size. The options are `skip` (filter out rows with invalid vectors), `error` (throw an error) and `optimistic` (do not check the vector size, and keep all rows). `error` by default.", ParamValidators$.MODULE$.inArray(VectorSizeHint$.MODULE$.supportedHandleInvalids()));
        setDefault(handleInvalid(), VectorSizeHint$.MODULE$.ERROR_INVALID());
    }

    /** Creates a VectorSizeHint with a random uid prefixed "vectSizeHint". */
    public VectorSizeHint() {
        this(Identifiable$.MODULE$.randomUID("vectSizeHint"));
    }

    /**
     * Reflects {@code org.apache.spark.ml.linalg.Vector} for the runtime TypeTags that
     * {@code functions.udf} requires. It is stateless, so one instance may back several TypeTags.
     */
    private static final class VectorTypeCreator extends TypeCreator {
        public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
            return mirror.staticClass("org.apache.spark.ml.linalg.Vector").asType().toTypeConstructor();
        }
    }

    /**
     * Rethrows {@code t} unchanged while bypassing Java's checked-exception analysis.
     *
     * <p>The Scala original throws the checked SparkException from inside a Function1, and Scala has
     * no checked exceptions. This helper preserves that exact runtime exception. Call it as
     * {@code throw sneakyThrow(e);} so the compiler sees a terminating statement. It never returns
     * normally.
     */
    @SuppressWarnings("unchecked")
    private static <E extends Throwable> RuntimeException sneakyThrow(Throwable t) throws E {
        throw (E) t;
    }
}
