package org.apache.spark.ml.regression;

import breeze.linalg.DenseVector;
import breeze.linalg.DenseVector$;
import breeze.optimize.CachedDiffFunction;
import breeze.optimize.FirstOrderMinimizer;
import breeze.optimize.LBFGS;
import java.io.IOException;
import org.apache.spark.SparkException;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.feature.Instance;
import org.apache.spark.ml.feature.InstanceBlock;
import org.apache.spark.ml.feature.InstanceBlock$;
import org.apache.spark.ml.feature.StandardScalerModel$;
import org.apache.spark.ml.linalg.BLAS$;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.ml.optim.aggregator.AFTBlockAggregator;
import org.apache.spark.ml.optim.loss.RDDLossFunction;
import org.apache.spark.ml.param.BooleanParam;
import org.apache.spark.ml.param.DoubleArrayParam;
import org.apache.spark.ml.param.DoubleParam;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.ParamValidators$;
import org.apache.spark.ml.param.shared.HasAggregationDepth;
import org.apache.spark.ml.param.shared.HasFitIntercept;
import org.apache.spark.ml.param.shared.HasMaxBlockSizeInMB;
import org.apache.spark.ml.stat.Summarizer$;
import org.apache.spark.ml.stat.SummarizerBuffer;
import org.apache.spark.ml.util.DatasetUtils$;
import org.apache.spark.ml.util.DefaultParamsWritable;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.ml.util.Instrumentation$;
import org.apache.spark.ml.util.MLReader;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.ml.util.MLWriter;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row$;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.storage.StorageLevel$;
import scala.Array$;
import scala.Function1;
import scala.MatchError;
import scala.None$;
import scala.Predef$;
import scala.Some;
import scala.Tuple2;
import scala.collection.Iterator;
import scala.collection.SeqLike;
import scala.collection.mutable.ArrayBuilder;
import scala.collection.mutable.ArrayBuilder$;
import scala.collection.mutable.ArrayOps;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichDouble$;
import scala.runtime.RichInt$;
import scala.runtime.ScalaRunTime$;

/* compiled from: AFTSurvivalRegression.scala */
@ScalaSignature(bytes = "\u0006\u0001\tMa\u0001\u0002\f\u0018\u0001\tB\u0001\u0002\u0011\u0001\u0003\u0006\u0004%\t%\u0011\u0005\t1\u0002\u0011\t\u0011)A\u0005\u0005\")!\f\u0001C\u00017\")!\f\u0001C\u0001?\")\u0011\r\u0001C\u0001E\")q\r\u0001C\u0001Q\")!\u000f\u0001C\u0001g\")a\u000f\u0001C\u0001o\")Q\u0010\u0001C\u0001}\"9\u0011\u0011\u0002\u0001\u0005\u0002\u0005-\u0001bBA\t\u0001\u0011\u0005\u00111\u0003\u0005\b\u0003;\u0001A\u0011AA\u0010\u0011\u001d\tI\u0003\u0001C)\u0003WAq!!\u0016\u0001\t\u0013\t9\u0006C\u0004\u0002*\u0002!\t%a+\t\u000f\u0005}\u0006\u0001\"\u0011\u0002B\u001e9\u0011q[\f\t\u0002\u0005egA\u0002\f\u0018\u0011\u0003\tY\u000e\u0003\u0004[%\u0011\u0005\u0011q\u001e\u0005\b\u0003c\u0014B\u0011IAz\u0011%\tYPEA\u0001\n\u0013\tiPA\u000bB\rR\u001bVO\u001d<jm\u0006d'+Z4sKN\u001c\u0018n\u001c8\u000b\u0005aI\u0012A\u0003:fOJ,7o]5p]*\u0011!dG\u0001\u0003[2T!\u0001H\u000f\u0002\u000bM\u0004\u0018M]6\u000b\u0005yy\u0012AB1qC\u000eDWMC\u0001!\u0003\ry'oZ\u0002\u0001'\u0015\u00011%\r\u001b;!\u0015!SeJ\u0017/\u001b\u00059\u0012B\u0001\u0014\u0018\u0005%\u0011Vm\u001a:fgN|'\u000f\u0005\u0002)W5\t\u0011F\u0003\u0002+3\u00051A.\u001b8bY\u001eL!\u0001L\u0015\u0003\rY+7\r^8s!\t!\u0003\u0001\u0005\u0002%_%\u0011\u0001g\u0006\u0002\u001b\u0003\u001a#6+\u001e:wSZ\fGNU3he\u0016\u001c8/[8o\u001b>$W\r\u001c\t\u0003IIJ!aM\f\u00037\u00053EkU;sm&4\u0018\r\u001c*fOJ,7o]5p]B\u000b'/Y7t!\t)\u0004(D\u00017\u0015\t9\u0014$\u0001\u0003vi&d\u0017BA\u001d7\u0005U!UMZ1vYR\u0004\u0016M]1ng^\u0013\u0018\u000e^1cY\u0016\u0004\"a\u000f 
\u000e\u0003qR!!P\u000e\u0002\u0011%tG/\u001a:oC2L!a\u0010\u001f\u0003\u000f1{wmZ5oO\u0006\u0019Q/\u001b3\u0016\u0003\t\u0003\"a\u0011'\u000f\u0005\u0011S\u0005CA#I\u001b\u00051%BA$\"\u0003\u0019a$o\\8u})\t\u0011*A\u0003tG\u0006d\u0017-\u0003\u0002L\u0011\u00061\u0001K]3eK\u001aL!!\u0014(\u0003\rM#(/\u001b8h\u0015\tY\u0005\nK\u0002\u0002!Z\u0003\"!\u0015+\u000e\u0003IS!aU\u000e\u0002\u0015\u0005tgn\u001c;bi&|g.\u0003\u0002V%\n)1+\u001b8dK\u0006\nq+A\u00032]Yr\u0003'\u0001\u0003vS\u0012\u0004\u0003f\u0001\u0002Q-\u00061A(\u001b8jiz\"\"!\f/\t\u000b\u0001\u001b\u0001\u0019\u0001\")\u0007q\u0003f\u000bK\u0002\u0004!Z#\u0012!\f\u0015\u0004\tA3\u0016\u0001D:fi\u000e+gn]8s\u0007>dGCA2e\u001b\u0005\u0001\u0001\"B3\u0006\u0001\u0004\u0011\u0015!\u0002<bYV,\u0007fA\u0003Q-\u0006A2/\u001a;Rk\u0006tG/\u001b7f!J|'-\u00192jY&$\u0018.Z:\u0015\u0005\rL\u0007\"B3\u0007\u0001\u0004Q\u0007cA6m]6\t\u0001*\u0003\u0002n\u0011\n)\u0011I\u001d:bsB\u00111n\\\u0005\u0003a\"\u0013a\u0001R8vE2,\u0007f\u0001\u0004Q-\u0006y1/\u001a;Rk\u0006tG/\u001b7fg\u000e{G\u000e\u0006\u0002di\")Qm\u0002a\u0001\u0005\"\u001aq\u0001\u0015,\u0002\u001fM,GOR5u\u0013:$XM]2faR$\"a\u0019=\t\u000b\u0015D\u0001\u0019A=\u0011\u0005-T\u0018BA>I\u0005\u001d\u0011un\u001c7fC:D3\u0001\u0003)W\u0003)\u0019X\r^'bq&#XM\u001d\u000b\u0003G~Da!Z\u0005A\u0002\u0005\u0005\u0001cA6\u0002\u0004%\u0019\u0011Q\u0001%\u0003\u0007%sG\u000fK\u0002\n!Z\u000baa]3u)>dGcA2\u0002\u000e!)QM\u0003a\u0001]\"\u001a!\u0002\u0015,\u0002'M,G/Q4he\u0016<\u0017\r^5p]\u0012+\u0007\u000f\u001e5\u0015\u0007\r\f)\u0002\u0003\u0004f\u0017\u0001\u0007\u0011\u0011\u0001\u0015\u0005\u0017A\u000bI\"\t\u0002\u0002\u001c\u0005)!GL\u0019/a\u0005\u00192/\u001a;NCb\u0014En\\2l'&TX-\u00138N\u0005R\u00191-!\t\t\u000b\u0015d\u0001\u0019\u00018)\t1\u0001\u0016QE\u0011\u0003\u0003O\tQa\r\u00182]A\nQ\u0001\u001e:bS:$2ALA\u0017\u0011\u001d\ty#\u0004a\u0001\u0003c\tq\u0001Z1uCN,G\u000f\r\u0003\u00024\u0005\r\u0003CBA\u001b\u0003w\ty$\u0004\u0002\u00028)\u0019\u0011\u0011H
\u000e\u0002\u0007M\fH.\u0003\u0003\u0002>\u0005]\"a\u0002#bi\u0006\u001cX\r\u001e\t\u0005\u0003\u0003\n\u0019\u0005\u0004\u0001\u0005\u0019\u0005\u0015\u0013QFA\u0001\u0002\u0003\u0015\t!a\u0012\u0003\u0007}#\u0013'\u0005\u0003\u0002J\u0005=\u0003cA6\u0002L%\u0019\u0011Q\n%\u0003\u000f9{G\u000f[5oOB\u00191.!\u0015\n\u0007\u0005M\u0003JA\u0002B]f\f\u0011\u0002\u001e:bS:LU\u000e\u001d7\u0015\u001d\u0005e\u0013qLA>\u0003\u007f\n\u0019)a\"\u0002&B)1.a\u0017kU&\u0019\u0011Q\f%\u0003\rQ+\b\u000f\\33\u0011\u001d\t\tG\u0004a\u0001\u0003G\n\u0011\"\u001b8ti\u0006t7-Z:\u0011\r\u0005\u0015\u00141NA8\u001b\t\t9GC\u0002\u0002jm\t1A\u001d3e\u0013\u0011\ti'a\u001a\u0003\u0007I#E\t\u0005\u0003\u0002r\u0005]TBAA:\u0015\r\t)(G\u0001\bM\u0016\fG/\u001e:f\u0013\u0011\tI(a\u001d\u0003\u0011%s7\u000f^1oG\u0016Da!! \u000f\u0001\u0004q\u0017aE1diV\fGN\u00117pG.\u001c\u0016N_3J]6\u0013\u0005BBAA\u001d\u0001\u0007!.A\u0006gK\u0006$XO]3t'R$\u0007BBAC\u001d\u0001\u0007!.\u0001\u0007gK\u0006$XO]3t\u001b\u0016\fg\u000eC\u0004\u0002\n:\u0001\r!a#\u0002\u0013=\u0004H/[7ju\u0016\u0014\bCBAG\u0003/\u000bY*\u0004\u0002\u0002\u0010*!\u0011\u0011SAJ\u0003!y\u0007\u000f^5nSj,'BAAK\u0003\u0019\u0011'/Z3{K&!\u0011\u0011TAH\u0005\u0015a%IR$T!\u0015\ti*!)o\u001b\t\tyJC\u0002+\u0003'KA!a)\u0002 
\nYA)\u001a8tKZ+7\r^8s\u0011\u0019\t9K\u0004a\u0001U\u0006y\u0011N\\5uS\u0006d7k\u001c7vi&|g.A\bue\u0006t7OZ8s[N\u001b\u0007.Z7b)\u0011\ti+!/\u0011\t\u0005=\u0016QW\u0007\u0003\u0003cSA!a-\u00028\u0005)A/\u001f9fg&!\u0011qWAY\u0005)\u0019FO];diRK\b/\u001a\u0005\b\u0003w{\u0001\u0019AAW\u0003\u0019\u00198\r[3nC\"\u001aq\u0002\u0015,\u0002\t\r|\u0007/\u001f\u000b\u0004[\u0005\r\u0007bBAc!\u0001\u0007\u0011qY\u0001\u0006Kb$(/\u0019\t\u0005\u0003\u0013\fy-\u0004\u0002\u0002L*\u0019\u0011QZ\r\u0002\u000bA\f'/Y7\n\t\u0005E\u00171\u001a\u0002\t!\u0006\u0014\u0018-\\'ba\"\u001a\u0001\u0003\u0015,)\u0007\u0001\u0001f+A\u000bB\rR\u001bVO\u001d<jm\u0006d'+Z4sKN\u001c\u0018n\u001c8\u0011\u0005\u0011\u00122c\u0002\n\u0002^\u0006\r\u0018\u0011\u001e\t\u0004W\u0006}\u0017bAAq\u0011\n1\u0011I\\=SK\u001a\u0004B!NAs[%\u0019\u0011q\u001d\u001c\u0003+\u0011+g-Y;miB\u000b'/Y7t%\u0016\fG-\u00192mKB\u00191.a;\n\u0007\u00055\bJ\u0001\u0007TKJL\u0017\r\\5{C\ndW\r\u0006\u0002\u0002Z\u0006!An\\1e)\ri\u0013Q\u001f\u0005\u0007\u0003o$\u0002\u0019\u0001\"\u0002\tA\fG\u000f\u001b\u0015\u0004)A3\u0016a\u0003:fC\u0012\u0014Vm]8mm\u0016$\"!a@\u0011\t\t\u0005!1B\u0007\u0003\u0005\u0007QAA!\u0002\u0003\b\u0005!A.\u00198h\u0015\t\u0011I!\u0001\u0003kCZ\f\u0017\u0002\u0002B\u0007\u0005\u0007\u0011aa\u00142kK\u000e$\bf\u0001\nQ-\"\u001a\u0011\u0003\u0015,")
/* loaded from: input_file:org/apache/spark/ml/regression/AFTSurvivalRegression.class */
/*
 * Estimator that fits an AFTSurvivalRegressionModel with Breeze L-BFGS over standardized,
 * blockified training instances.
 *
 * This file is Java decompiled from Scala bytecode (AFTSurvivalRegression.scala). Several
 * members exist only because of the Scala trait encoding:
 *   - the `..._setter_$..._$eq` methods are invoked while the traits initialise (two of them
 *     directly from the constructor, the rest presumably from the traits' static `$init$`
 *     methods) to install the Param objects;
 *   - the `getXxx()` / `write()` / `save()` forwarders delegate to the trait default methods.
 */
public class AFTSurvivalRegression extends Regressor<Vector, AFTSurvivalRegression, AFTSurvivalRegressionModel> implements AFTSurvivalRegressionParams, DefaultParamsWritable {
    /** Unique identifier of this estimator instance. */
    private final String uid;

    // The Param fields below are written by the trait `_setter_` methods during construction,
    // so they cannot be declared `final` in Java source. They are assigned exactly once.
    private Param<String> censorCol;
    private DoubleArrayParam quantileProbabilities;
    private Param<String> quantilesCol;
    private DoubleParam maxBlockSizeInMB;
    private IntParam aggregationDepth;
    private BooleanParam fitIntercept;
    private DoubleParam tol;
    private IntParam maxIter;

    /** Loads an estimator previously saved at {@code str}; delegates to the Scala companion object. */
    public static AFTSurvivalRegression load(String str) {
        return AFTSurvivalRegression$.MODULE$.load(str);
    }

    /** Returns the reader for this estimator type; delegates to the Scala companion object. */
    public static MLReader<AFTSurvivalRegression> read() {
        return AFTSurvivalRegression$.MODULE$.read();
    }

    // ---------------------------------------------------------------------------------------
    // Trait forwarders.
    // The decompiler rendered each of these as a call to itself (infinite recursion). They
    // must delegate to the inherited trait implementation instead.
    // NOTE(review): assumes the trait methods are interface default methods (Scala 2.12+
    // encoding, consistent with the static `X.$init$(this)` calls in the constructor) and
    // that the shared-param traits are reachable through AFTSurvivalRegressionParams — confirm.
    // ---------------------------------------------------------------------------------------

    @Override // org.apache.spark.ml.util.DefaultParamsWritable, org.apache.spark.ml.util.MLWritable
    public MLWriter write() {
        return DefaultParamsWritable.super.write();
    }

    @Override // org.apache.spark.ml.util.MLWritable
    public void save(String str) throws IOException {
        // MLWritable is not a direct superinterface, so go through DefaultParamsWritable.
        DefaultParamsWritable.super.save(str);
    }

    @Override // org.apache.spark.ml.regression.AFTSurvivalRegressionParams
    public String getCensorCol() {
        return AFTSurvivalRegressionParams.super.getCensorCol();
    }

    @Override // org.apache.spark.ml.regression.AFTSurvivalRegressionParams
    public double[] getQuantileProbabilities() {
        return AFTSurvivalRegressionParams.super.getQuantileProbabilities();
    }

    @Override // org.apache.spark.ml.regression.AFTSurvivalRegressionParams
    public String getQuantilesCol() {
        return AFTSurvivalRegressionParams.super.getQuantilesCol();
    }

    @Override // org.apache.spark.ml.regression.AFTSurvivalRegressionParams
    public boolean hasQuantilesCol() {
        return AFTSurvivalRegressionParams.super.hasQuantilesCol();
    }

    @Override // org.apache.spark.ml.regression.AFTSurvivalRegressionParams
    public StructType validateAndTransformSchema(StructType structType, boolean z) {
        return AFTSurvivalRegressionParams.super.validateAndTransformSchema(structType, z);
    }

    @Override // org.apache.spark.ml.param.shared.HasMaxBlockSizeInMB
    public final double getMaxBlockSizeInMB() {
        return AFTSurvivalRegressionParams.super.getMaxBlockSizeInMB();
    }

    @Override // org.apache.spark.ml.param.shared.HasAggregationDepth
    public final int getAggregationDepth() {
        return AFTSurvivalRegressionParams.super.getAggregationDepth();
    }

    @Override // org.apache.spark.ml.param.shared.HasFitIntercept
    public final boolean getFitIntercept() {
        return AFTSurvivalRegressionParams.super.getFitIntercept();
    }

    @Override // org.apache.spark.ml.param.shared.HasTol
    public final double getTol() {
        return AFTSurvivalRegressionParams.super.getTol();
    }

    @Override // org.apache.spark.ml.param.shared.HasMaxIter
    public final int getMaxIter() {
        return AFTSurvivalRegressionParams.super.getMaxIter();
    }

    // ---------------------------------------------------------------------------------------
    // Param accessors and the trait-initialisation setters that populate them.
    // ---------------------------------------------------------------------------------------

    @Override // org.apache.spark.ml.regression.AFTSurvivalRegressionParams
    public final Param<String> censorCol() {
        return this.censorCol;
    }

    @Override // org.apache.spark.ml.regression.AFTSurvivalRegressionParams
    public final DoubleArrayParam quantileProbabilities() {
        return this.quantileProbabilities;
    }

    @Override // org.apache.spark.ml.regression.AFTSurvivalRegressionParams
    public final Param<String> quantilesCol() {
        return this.quantilesCol;
    }

    @Override // org.apache.spark.ml.regression.AFTSurvivalRegressionParams
    public final void org$apache$spark$ml$regression$AFTSurvivalRegressionParams$_setter_$censorCol_$eq(Param<String> param) {
        this.censorCol = param;
    }

    @Override // org.apache.spark.ml.regression.AFTSurvivalRegressionParams
    public final void org$apache$spark$ml$regression$AFTSurvivalRegressionParams$_setter_$quantileProbabilities_$eq(DoubleArrayParam doubleArrayParam) {
        this.quantileProbabilities = doubleArrayParam;
    }

    @Override // org.apache.spark.ml.regression.AFTSurvivalRegressionParams
    public final void org$apache$spark$ml$regression$AFTSurvivalRegressionParams$_setter_$quantilesCol_$eq(Param<String> param) {
        this.quantilesCol = param;
    }

    @Override // org.apache.spark.ml.param.shared.HasMaxBlockSizeInMB
    public final DoubleParam maxBlockSizeInMB() {
        return this.maxBlockSizeInMB;
    }

    @Override // org.apache.spark.ml.param.shared.HasMaxBlockSizeInMB
    public final void org$apache$spark$ml$param$shared$HasMaxBlockSizeInMB$_setter_$maxBlockSizeInMB_$eq(DoubleParam doubleParam) {
        this.maxBlockSizeInMB = doubleParam;
    }

    @Override // org.apache.spark.ml.param.shared.HasAggregationDepth
    public final IntParam aggregationDepth() {
        return this.aggregationDepth;
    }

    @Override // org.apache.spark.ml.param.shared.HasAggregationDepth
    public final void org$apache$spark$ml$param$shared$HasAggregationDepth$_setter_$aggregationDepth_$eq(IntParam intParam) {
        this.aggregationDepth = intParam;
    }

    @Override // org.apache.spark.ml.param.shared.HasFitIntercept
    public final BooleanParam fitIntercept() {
        return this.fitIntercept;
    }

    @Override // org.apache.spark.ml.param.shared.HasFitIntercept
    public final void org$apache$spark$ml$param$shared$HasFitIntercept$_setter_$fitIntercept_$eq(BooleanParam booleanParam) {
        this.fitIntercept = booleanParam;
    }

    @Override // org.apache.spark.ml.param.shared.HasTol
    public final DoubleParam tol() {
        return this.tol;
    }

    @Override // org.apache.spark.ml.param.shared.HasTol
    public final void org$apache$spark$ml$param$shared$HasTol$_setter_$tol_$eq(DoubleParam doubleParam) {
        this.tol = doubleParam;
    }

    @Override // org.apache.spark.ml.param.shared.HasMaxIter
    public final IntParam maxIter() {
        return this.maxIter;
    }

    @Override // org.apache.spark.ml.param.shared.HasMaxIter
    public final void org$apache$spark$ml$param$shared$HasMaxIter$_setter_$maxIter_$eq(IntParam intParam) {
        this.maxIter = intParam;
    }

    @Override // org.apache.spark.ml.util.Identifiable
    public String uid() {
        return this.uid;
    }

    // ---------------------------------------------------------------------------------------
    // Fluent setters. Each stores the value under the corresponding Param and returns this
    // estimator. The decompiler emitted impossible casts here (e.g. String -> Param<String>);
    // the value is passed through unchanged instead.
    // NOTE(review): assumes the inherited `set` has the generic shape set(Param<T>, T) — confirm.
    // ---------------------------------------------------------------------------------------

    public AFTSurvivalRegression setCensorCol(String str) {
        return (AFTSurvivalRegression) set(censorCol(), str);
    }

    public AFTSurvivalRegression setQuantileProbabilities(double[] dArr) {
        return (AFTSurvivalRegression) set(quantileProbabilities(), dArr);
    }

    public AFTSurvivalRegression setQuantilesCol(String str) {
        return (AFTSurvivalRegression) set(quantilesCol(), str);
    }

    public AFTSurvivalRegression setFitIntercept(boolean z) {
        return (AFTSurvivalRegression) set(fitIntercept(), BoxesRunTime.boxToBoolean(z));
    }

    public AFTSurvivalRegression setMaxIter(int i) {
        return (AFTSurvivalRegression) set(maxIter(), BoxesRunTime.boxToInteger(i));
    }

    public AFTSurvivalRegression setTol(double d) {
        return (AFTSurvivalRegression) set(tol(), BoxesRunTime.boxToDouble(d));
    }

    public AFTSurvivalRegression setAggregationDepth(int i) {
        return (AFTSurvivalRegression) set(aggregationDepth(), BoxesRunTime.boxToInteger(i));
    }

    public AFTSurvivalRegression setMaxBlockSizeInMB(double d) {
        return (AFTSurvivalRegression) set(maxBlockSizeInMB(), BoxesRunTime.boxToDouble(d));
    }

    /**
     * Fits a model on {@code dataset}.
     *
     * <p>Steps, all executed inside the instrumentation wrapper:
     * <ol>
     *   <li>log the stage, dataset and params; warn if the input is already cached;</li>
     *   <li>select (label, censor, features), rejecting null/NaN censors and censors outside
     *       {0, 1} through {@code raise_error} column expressions;</li>
     *   <li>summarise the features (mean, std, count) with a tree aggregate;</li>
     *   <li>resolve the block size (0 means "use the InstanceBlock default");</li>
     *   <li>run L-BFGS via {@link #trainImpl} starting from an all-zero vector of
     *       {@code numFeatures + 2} entries;</li>
     *   <li>map the raw solution back to the original feature scale.</li>
     * </ol>
     *
     * @param dataset training data containing the label, censor and features columns
     * @return the fitted model
     * @throws SparkException if the optimizer produced no solution
     */
    @Override // org.apache.spark.ml.Predictor
    public AFTSurvivalRegressionModel train(Dataset<?> dataset) {
        return (AFTSurvivalRegressionModel) Instrumentation$.MODULE$.instrumented(instrumentation -> {
            instrumentation.logPipelineStage(this);
            instrumentation.logDataset((Dataset<?>) dataset);
            instrumentation.logParams(this, Predef$.MODULE$.wrapRefArray(new Param[]{this.labelCol(), this.featuresCol(), this.censorCol(), this.predictionCol(), this.quantilesCol(), this.fitIntercept(), this.maxIter(), this.tol(), this.aggregationDepth(), this.maxBlockSizeInMB()}));
            instrumentation.logNamedValue("quantileProbabilities.size", ((double[]) this.$(this.quantileProbabilities())).length);

            // Training caches its own blockified copy, so an already-cached input is cached twice.
            StorageLevel storageLevel = dataset.storageLevel();
            StorageLevel NONE = StorageLevel$.MODULE$.NONE();
            if (storageLevel == null ? NONE != null : !storageLevel.equals(NONE)) {
                instrumentation.logWarning(() -> {
                    return "Input instances will be standardized, blockified to blocks, and "
                            + "then cached during training. Be careful of double caching!";
                });
            }

            // Column expressions are built in the same order as the original select(...) list.
            Column censorCast = functions$.MODULE$.col((String) this.$(this.censorCol())).cast(DoubleType$.MODULE$);
            Column labelColumn = DatasetUtils$.MODULE$.checkRegressionLabels((String) this.$(this.labelCol()));
            Column censorColumn = functions$.MODULE$
                    .when(censorCast.isNull().$bar$bar(censorCast.isNaN()), functions$.MODULE$.raise_error(functions$.MODULE$.lit("Censors MUST NOT be Null or NaN")))
                    .when(censorCast.$eq$bang$eq(BoxesRunTime.boxToInteger(0)).$amp$amp(censorCast.$eq$bang$eq(BoxesRunTime.boxToInteger(1))), functions$.MODULE$.raise_error(functions$.MODULE$.concat(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.lit("Censors MUST be in {0, 1}, but got "), censorCast}))))
                    .otherwise(censorCast);
            Column featuresColumn = DatasetUtils$.MODULE$.checkNonNanVectors((String) this.$(this.featuresCol()));

            RDD<Instance> instances = dataset.select(Predef$.MODULE$.wrapRefArray(new Column[]{labelColumn, censorColumn, featuresColumn})).rdd().map(row -> {
                // Scala pattern match `case Row(l: Double, c: Double, v: Vector)`.
                Some fields = Row$.MODULE$.unapplySeq(row);
                if (!fields.isEmpty() && fields.get() != null && ((SeqLike) fields.get()).lengthCompare(3) == 0) {
                    SeqLike values = (SeqLike) fields.get();
                    Object rawLabel = values.apply(0);
                    Object rawCensor = values.apply(1);
                    Object rawFeatures = values.apply(2);
                    if (rawLabel instanceof Double && rawCensor instanceof Double && rawFeatures instanceof Vector) {
                        // The censor value travels in the Instance's second (weight) slot.
                        return new Instance(BoxesRunTime.unboxToDouble(rawLabel), BoxesRunTime.unboxToDouble(rawCensor), (Vector) rawFeatures);
                    }
                }
                throw new MatchError(row);
            }, ClassTag$.MODULE$.apply(Instance.class)).setName("training instances");

            SummarizerBuffer summarizer = (SummarizerBuffer) instances.treeAggregate(Summarizer$.MODULE$.createSummarizerBuffer(Predef$.MODULE$.wrapRefArray(new String[]{"mean", "std", "count"})), (buffer, instance) -> {
                return buffer.add(instance.features());
            }, (left, right) -> {
                return left.merge(right);
            }, BoxesRunTime.unboxToInt(this.$(this.aggregationDepth())), ClassTag$.MODULE$.apply(SummarizerBuffer.class));
            double[] featuresMean = summarizer.mean().toArray();
            double[] featuresStd = summarizer.std().toArray();
            int numFeatures = featuresStd.length;
            instrumentation.logNumFeatures(numFeatures);
            instrumentation.logNumExamples(summarizer.count());

            // A configured block size of 0 means "use the library default".
            double actualBlockSizeInMB = BoxesRunTime.unboxToDouble(this.$(this.maxBlockSizeInMB()));
            if (actualBlockSizeInMB == 0) {
                actualBlockSizeInMB = InstanceBlock$.MODULE$.DefaultBlockSizeInMB();
                Predef$.MODULE$.require(actualBlockSizeInMB > ((double) 0), () -> {
                    return "inferred actual BlockSizeInMB must > 0";
                });
                instrumentation.logNamedValue("actualBlockSizeInMB", Double.toString(actualBlockSizeInMB));
            }

            // Constant non-zero columns get a zero coefficient when no intercept is fitted.
            if (!BoxesRunTime.unboxToBoolean(this.$(this.fitIntercept())) && RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), numFeatures).exists(i -> {
                return featuresStd[i] == 0.0d && summarizer.mean().apply(i) != 0.0d;
            })) {
                instrumentation.logWarning(() -> {
                    return "Fitting AFTSurvivalRegressionModel without intercept on dataset with constant nonzero column, Spark MLlib outputs zero coefficients for constant nonzero columns. This behavior is different from R survival::survreg.";
                });
            }

            LBFGS<DenseVector<Object>> lbfgs = new LBFGS<>(BoxesRunTime.unboxToInt(this.$(this.maxIter())), 10, BoxesRunTime.unboxToDouble(this.$(this.tol())), DenseVector$.MODULE$.space_Double());
            // Initial solution: numFeatures coefficients + 2 extra slots (read back below as
            // intercept at [numFeatures] and log-scale at [numFeatures + 1]), all zero.
            Tuple2<double[], double[]> fit = this.trainImpl(instances, actualBlockSizeInMB, featuresStd, featuresMean, lbfgs, (double[]) Array$.MODULE$.ofDim(numFeatures + 2, ClassTag$.MODULE$.Double()));
            if (fit == null) {
                throw new MatchError(fit);
            }
            double[] rawCoefficients = (double[]) fit._1();
            if (rawCoefficients != null) {
                // Undo the standardisation: divide by std, or 0 for constant features.
                return new AFTSurvivalRegressionModel(this.uid(), Vectors$.MODULE$.dense((double[]) Array$.MODULE$.tabulate(numFeatures, j -> {
                    if (featuresStd[j] != 0) {
                        return rawCoefficients[j] / featuresStd[j];
                    }
                    return 0.0d;
                }, ClassTag$.MODULE$.Double())), rawCoefficients[numFeatures], package$.MODULE$.exp(rawCoefficients[numFeatures + 1]));
            }
            String message = lbfgs.getClass().getName() + " failed.";
            instrumentation.logError(() -> {
                return message;
            });
            throw new SparkException(message);
        });
    }

    /**
     * Runs the optimizer over standardized, blockified instances.
     *
     * @param rdd    training instances
     * @param d      block size in MB used when blockifying
     * @param dArr   per-feature standard deviations (as passed by {@link #train})
     * @param dArr2  per-feature means (as passed by {@link #train})
     * @param lbfgs  the configured optimizer
     * @param dArr3  initial solution; NOTE: element {@code [dArr.length]} is modified in place
     *               when fitIntercept is set
     * @return a pair of (solution array, or {@code null} if the optimizer produced no state;
     *         the {@code adjustedValue} recorded at every iteration)
     */
    private Tuple2<double[], double[]> trainImpl(RDD<Instance> rdd, double d, double[] dArr, double[] dArr2, LBFGS<DenseVector<Object>> lbfgs, double[] dArr3) {
        int numFeatures = dArr.length;
        // 1/std per feature; constant features (std == 0) are scaled by 0.
        double[] inverseStd = (double[]) new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(dArr)).map(std -> {
            return std != 0 ? 1.0d / std : 0.0d;
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Double()));
        double[] scaledMean = (double[]) Array$.MODULE$.tabulate(numFeatures, i -> {
            return inverseStd[i] * dArr2[i];
        }, ClassTag$.MODULE$.Double());
        Broadcast bcInverseStd = rdd.context().broadcast(inverseStd, ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Double.TYPE)));
        Broadcast bcScaledMean = rdd.context().broadcast(scaledMean, ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Double.TYPE)));
        RDD blocks = InstanceBlock$.MODULE$.blokifyWithMaxMemUsage(rdd.mapPartitions(iterator -> {
            // Scale features by 1/std only (no mean shift is passed to the transform).
            Function1<Vector, Vector> transformFunc = StandardScalerModel$.MODULE$.getTransformFunc((double[]) Array$.MODULE$.empty(ClassTag$.MODULE$.Double()), (double[]) bcInverseStd.value(), false, true);
            return iterator.map(instance -> {
                if (instance != null) {
                    return new Instance(instance.label(), instance.weight(), (Vector) transformFunc.apply(instance.features()));
                }
                throw new MatchError(instance);
            });
        }, rdd.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Instance.class)), (long) RichDouble$.MODULE$.ceil$extension(Predef$.MODULE$.doubleWrapper(d * 1024 * 1024))).persist(StorageLevel$.MODULE$.MEMORY_AND_DISK()).setName("training blocks (blockSizeInMB=" + d + ")");

        ArrayBuilder lossHistory = ArrayBuilder$.MODULE$.make(ClassTag$.MODULE$.Double());
        FirstOrderMinimizer.State state = null;
        try {
            RDDLossFunction costFun = new RDDLossFunction(blocks, bc -> {
                return new AFTBlockAggregator(bcScaledMean, BoxesRunTime.unboxToBoolean(this.$(this.fitIntercept())), bc);
            }, None$.MODULE$, BoxesRunTime.unboxToInt($(aggregationDepth())), ClassTag$.MODULE$.apply(InstanceBlock.class), ClassTag$.MODULE$.apply(AFTBlockAggregator.class));
            if (BoxesRunTime.unboxToBoolean($(fitIntercept()))) {
                // Shift the intercept slot by coefficients . scaledMean before optimizing;
                // the inverse shift is applied to the solution below.
                dArr3[numFeatures] += BLAS$.MODULE$.javaBLAS().ddot(numFeatures, dArr3, 1, scaledMean, 1);
            }
            Iterator iterations = lbfgs.iterations(new CachedDiffFunction(costFun, DenseVector$.MODULE$.canCopyDenseVector(ClassTag$.MODULE$.Double())), new DenseVector.mcD.sp(dArr3));
            while (iterations.hasNext()) {
                state = (FirstOrderMinimizer.State) iterations.next();
                lossHistory.$plus$eq(BoxesRunTime.boxToDouble(state.adjustedValue()));
            }
        } finally {
            // Release cached blocks and broadcasts even when the optimizer throws.
            blocks.unpersist(blocks.unpersist$default$1());
            bcInverseStd.destroy();
            bcScaledMean.destroy();
        }

        double[] solution = state == null ? null : ((DenseVector) state.x()).toArray$mcD$sp(ClassTag$.MODULE$.Double());
        if (BoxesRunTime.unboxToBoolean($(fitIntercept())) && solution != null) {
            solution[numFeatures] -= BLAS$.MODULE$.getBLAS(numFeatures).ddot(numFeatures, solution, 1, scaledMean, 1);
        }
        return new Tuple2<>(solution, lossHistory.result());
    }

    /** Validates {@code structType} in fitting mode and returns the transformed schema. */
    @Override // org.apache.spark.ml.Predictor, org.apache.spark.ml.PipelineStage
    public StructType transformSchema(StructType structType) {
        return validateAndTransformSchema(structType, true);
    }

    /** Copies this estimator via {@code defaultCopy}, passing {@code paramMap} as the extra params. */
    @Override // org.apache.spark.ml.Predictor, org.apache.spark.ml.Estimator, org.apache.spark.ml.PipelineStage, org.apache.spark.ml.param.Params
    public AFTSurvivalRegression copy(ParamMap paramMap) {
        return (AFTSurvivalRegression) defaultCopy(paramMap);
    }

    // The decompiled synthetic bridge `PredictionModel train(Dataset)` is intentionally not
    // declared: it has the same erasure as train(Dataset<?>) (a name clash in Java source),
    // and javac generates the bridge automatically for the covariant override above.

    /**
     * Creates an estimator with the given uid and initialises the params contributed by
     * each mixed-in trait.
     *
     * @param str unique identifier for this estimator
     */
    public AFTSurvivalRegression(String str) {
        this.uid = str;
        org$apache$spark$ml$param$shared$HasMaxIter$_setter_$maxIter_$eq(new IntParam(this, "maxIter", "maximum number of iterations (>= 0)", (Function1<Object, Object>) ParamValidators$.MODULE$.gtEq(0.0d)));
        org$apache$spark$ml$param$shared$HasTol$_setter_$tol_$eq(new DoubleParam(this, "tol", "the convergence tolerance for iterative algorithms (>= 0)", (Function1<Object, Object>) ParamValidators$.MODULE$.gtEq(0.0d)));
        HasFitIntercept.$init$((HasFitIntercept) this);
        HasAggregationDepth.$init$((HasAggregationDepth) this);
        HasMaxBlockSizeInMB.$init$((HasMaxBlockSizeInMB) this);
        AFTSurvivalRegressionParams.$init$((AFTSurvivalRegressionParams) this);
        MLWritable.$init$(this);
        DefaultParamsWritable.$init$((DefaultParamsWritable) this);
    }

    /** Creates an estimator with a random uid prefixed {@code "aftSurvReg"}. */
    public AFTSurvivalRegression() {
        this(Identifiable$.MODULE$.randomUID("aftSurvReg"));
    }
}
