package org.apache.spark.ml.optim.aggregator;

import java.util.Arrays;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.internal.Logging;
import org.apache.spark.ml.feature.InstanceBlock;
import org.apache.spark.ml.impl.Utils$;
import org.apache.spark.ml.linalg.BLAS$;
import org.apache.spark.ml.linalg.DenseMatrix;
import org.apache.spark.ml.linalg.DenseMatrix$;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.DenseVector$;
import org.apache.spark.ml.linalg.Matrices$;
import org.apache.spark.ml.linalg.Matrix;
import org.apache.spark.ml.linalg.SparseMatrix;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors$;
import org.slf4j.Logger;
import scala.Array$;
import scala.Function0;
import scala.MatchError;
import scala.Option;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.mutable.ArrayOps;
import scala.math.Numeric$DoubleIsFractional$;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: LogisticAggregator.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005\u0015a!B\n\u0015\u0001a\u0001\u0003\u0002\u0003\u001d\u0001\u0005\u0003\u0005\u000b\u0011\u0002\u001e\t\u0011u\u0002!\u0011!Q\u0001\niB\u0001B\u0010\u0001\u0003\u0002\u0003\u0006Ia\u0010\u0005\t\u0005\u0002\u0011\t\u0011)A\u0005\u007f!A1\t\u0001B\u0001B\u0003%A\tC\u0003Q\u0001\u0011\u0005\u0011\u000bC\u0004Y\u0001\t\u0007I\u0011B-\t\ri\u0003\u0001\u0015!\u0003;\u0011\u001dY\u0006A1A\u0005\neCa\u0001\u0018\u0001!\u0002\u0013Q\u0004bB/\u0001\u0005\u0004%\t&\u0017\u0005\u0007=\u0002\u0001\u000b\u0011\u0002\u001e\t\u0011}\u0003\u0001R1A\u0005\n\u0001D\u0001b\u001b\u0001\t\u0006\u0004%I\u0001\u001c\u0005\t]\u0002A)\u0019!C\u0005_\")A\u000f\u0001C\u0001k\")\u0011\u0010\u0001C\u0005u\"1q\u0010\u0001C\u0005\u0003\u0003\u0011qC\u00117pG.dunZ5ti&\u001c\u0017iZ4sK\u001e\fGo\u001c:\u000b\u0005U1\u0012AC1hOJ,w-\u0019;pe*\u0011q\u0003G\u0001\u0006_B$\u0018.\u001c\u0006\u00033i\t!!\u001c7\u000b\u0005ma\u0012!B:qCJ\\'BA\u000f\u001f\u0003\u0019\t\u0007/Y2iK*\tq$A\u0002pe\u001e\u001cB\u0001A\u0011(eA\u0011!%J\u0007\u0002G)\tA%A\u0003tG\u0006d\u0017-\u0003\u0002'G\t1\u0011I\\=SK\u001a\u0004B\u0001K\u0015,c5\tA#\u0003\u0002+)\taB)\u001b4gKJ,g\u000e^5bE2,Gj\\:t\u0003\u001e<'/Z4bi>\u0014\bC\u0001\u00170\u001b\u0005i#B\u0001\u0018\u0019\u0003\u001d1W-\u0019;ve\u0016L!\u0001M\u0017\u0003\u001b%s7\u000f^1oG\u0016\u0014En\\2l!\tA\u0003\u0001\u0005\u00024m5\tAG\u0003\u000265\u0005A\u0011N\u001c;fe:\fG.\u0003\u00028i\t9Aj\\4hS:<\u0017a\u00038v[\u001a+\u0017\r^;sKN\u001c\u0001\u0001\u0005\u0002#w%\u0011Ah\t\u0002\u0004\u0013:$\u0018A\u00038v[\u000ec\u0017m]:fg\u0006aa-\u001b;J]R,'oY3qiB\u0011!\u0005Q\u0005\u0003\u0003\u000e\u0012qAQ8pY\u0016\fg.A\u0006nk2$\u0018N\\8nS\u0006d\u0017A\u00042d\u0007>,gMZ5dS\u0016tGo\u001d\t\u0004\u000b\"SU\"\u0001$\u000b\u0005\u001dS\u0012!\u00032s_\u0006$7-Y:u\u0013\tIeIA\u0005Ce>\fGmY1tiB\u00111JT\u0007\u0002\u0019*\u0011Q\nG\u0001\u0007Y&t\u0017\r\\4\n\u0005=c%A\u0002,fGR|'/\u0001\u0004=S:LGO\u0010\u000b\u0006%R+fk\u00
16\u000b\u0003cMCQa\u0011\u0004A\u0002\u0011CQ\u0001\u000f\u0004A\u0002iBQ!\u0010\u0004A\u0002iBQA\u0010\u0004A\u0002}BQA\u0011\u0004A\u0002}\n\u0001D\\;n\r\u0016\fG/\u001e:fgBcWo]%oi\u0016\u00148-\u001a9u+\u0005Q\u0014!\u00078v[\u001a+\u0017\r^;sKN\u0004F.^:J]R,'oY3qi\u0002\nqbY8fM\u001aL7-[3oiNK'0Z\u0001\u0011G>,gMZ5dS\u0016tGoU5{K\u0002\n1\u0001Z5n\u0003\u0011!\u0017.\u001c\u0011\u0002#\r|WM\u001a4jG&,g\u000e^:BeJ\f\u00170F\u0001b!\r\u0011#\rZ\u0005\u0003G\u000e\u0012Q!\u0011:sCf\u0004\"AI3\n\u0005\u0019\u001c#A\u0002#pk\ndW\r\u000b\u0002\u000eQB\u0011!%[\u0005\u0003U\u000e\u0012\u0011\u0002\u001e:b]NLWM\u001c;\u0002\u0019\tLg.\u0019:z\u0019&tW-\u0019:\u0016\u0003)C#A\u00045\u0002#5,H\u000e^5o_6L\u0017\r\u001c'j]\u0016\f'/F\u0001q!\tY\u0015/\u0003\u0002s\u0019\nYA)\u001a8tK6\u000bGO]5yQ\ty\u0001.A\u0002bI\u0012$\"A^<\u000e\u0003\u0001AQ\u0001\u001f\tA\u0002-\nQA\u00197pG.\f1CY5oCJLX\u000b\u001d3bi\u0016Le\u000e\u00157bG\u0016$\"a\u001f@\u0011\u0005\tb\u0018BA?$\u0005\u0011)f.\u001b;\t\u000ba\f\u0002\u0019A\u0016\u000215,H\u000e^5o_6L\u0017\r\\+qI\u0006$X-\u00138QY\u0006\u001cW\rF\u0002|\u0003\u0007AQ\u0001\u001f\nA\u0002-\u0002")
/* loaded from: input_file:org/apache/spark/ml/optim/aggregator/BlockLogisticAggregator.class */
public class BlockLogisticAggregator implements DifferentiableLossAggregator<InstanceBlock, BlockLogisticAggregator>, Logging {
    private transient double[] coefficientsArray;
    private transient Vector binaryLinear;
    private transient DenseMatrix multinomialLinear;
    private final int numFeatures;
    private final int numClasses;
    private final boolean fitIntercept;
    private final boolean multinomial;
    private final Broadcast<Vector> bcCoefficients;
    private final int numFeaturesPlusIntercept;
    private final int coefficientSize;
    private final int dim;
    private transient Logger org$apache$spark$internal$Logging$$log_;
    private double weightSum;
    private double lossSum;
    private double[] gradientSumArray;
    private volatile transient byte bitmap$trans$0;
    private volatile boolean bitmap$0;

    public String logName() {
        return Logging.logName$(this);
    }

    public Logger log() {
        return Logging.log$(this);
    }

    public void logInfo(Function0<String> function0) {
        Logging.logInfo$(this, function0);
    }

    public void logDebug(Function0<String> function0) {
        Logging.logDebug$(this, function0);
    }

    public void logTrace(Function0<String> function0) {
        Logging.logTrace$(this, function0);
    }

    public void logWarning(Function0<String> function0) {
        Logging.logWarning$(this, function0);
    }

    public void logError(Function0<String> function0) {
        Logging.logError$(this, function0);
    }

    public void logInfo(Function0<String> function0, Throwable th) {
        Logging.logInfo$(this, function0, th);
    }

    public void logDebug(Function0<String> function0, Throwable th) {
        Logging.logDebug$(this, function0, th);
    }

    public void logTrace(Function0<String> function0, Throwable th) {
        Logging.logTrace$(this, function0, th);
    }

    public void logWarning(Function0<String> function0, Throwable th) {
        Logging.logWarning$(this, function0, th);
    }

    public void logError(Function0<String> function0, Throwable th) {
        Logging.logError$(this, function0, th);
    }

    public boolean isTraceEnabled() {
        return Logging.isTraceEnabled$(this);
    }

    public void initializeLogIfNecessary(boolean z) {
        Logging.initializeLogIfNecessary$(this, z);
    }

    public boolean initializeLogIfNecessary(boolean z, boolean z2) {
        return Logging.initializeLogIfNecessary$(this, z, z2);
    }

    public boolean initializeLogIfNecessary$default$2() {
        return Logging.initializeLogIfNecessary$default$2$(this);
    }

    public void initializeForcefully(boolean z, boolean z2) {
        Logging.initializeForcefully$(this, z, z2);
    }

    /* JADX WARN: Type inference failed for: r0v1, types: [org.apache.spark.ml.optim.aggregator.BlockLogisticAggregator, org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator] */
    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public BlockLogisticAggregator merge(BlockLogisticAggregator blockLogisticAggregator) {
        ?? merge;
        merge = merge(blockLogisticAggregator);
        return merge;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public Vector gradient() {
        Vector gradient;
        gradient = gradient();
        return gradient;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public double weight() {
        double weight;
        weight = weight();
        return weight;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public double loss() {
        double loss;
        loss = loss();
        return loss;
    }

    public Logger org$apache$spark$internal$Logging$$log_() {
        return this.org$apache$spark$internal$Logging$$log_;
    }

    public void org$apache$spark$internal$Logging$$log__$eq(Logger logger) {
        this.org$apache$spark$internal$Logging$$log_ = logger;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public double weightSum() {
        return this.weightSum;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public void weightSum_$eq(double d) {
        this.weightSum = d;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public double lossSum() {
        return this.lossSum;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public void lossSum_$eq(double d) {
        this.lossSum = d;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v0 */
    /* JADX WARN: Type inference failed for: r0v1, types: [java.lang.Throwable] */
    /* JADX WARN: Type inference failed for: r0v8, types: [org.apache.spark.ml.optim.aggregator.BlockLogisticAggregator] */
    private double[] gradientSumArray$lzycompute() {
        double[] gradientSumArray;
        ?? r0 = this;
        synchronized (r0) {
            if (!this.bitmap$0) {
                gradientSumArray = gradientSumArray();
                this.gradientSumArray = gradientSumArray;
                r0 = this;
                r0.bitmap$0 = true;
            }
        }
        return this.gradientSumArray;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public double[] gradientSumArray() {
        return !this.bitmap$0 ? gradientSumArray$lzycompute() : this.gradientSumArray;
    }

    private int numFeaturesPlusIntercept() {
        return this.numFeaturesPlusIntercept;
    }

    private int coefficientSize() {
        return this.coefficientSize;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public int dim() {
        return this.dim;
    }

    /* JADX WARN: Multi-variable type inference failed */
    private double[] coefficientsArray$lzycompute() {
        synchronized (this) {
            if (((byte) (this.bitmap$trans$0 & 1)) == 0) {
                DenseVector denseVector = (Vector) this.bcCoefficients.value();
                if (denseVector instanceof DenseVector) {
                    Option unapply = DenseVector$.MODULE$.unapply(denseVector);
                    if (!unapply.isEmpty()) {
                        this.coefficientsArray = (double[]) unapply.get();
                        this.bitmap$trans$0 = (byte) (this.bitmap$trans$0 | 1);
                    }
                }
                throw new IllegalArgumentException(new StringBuilder(55).append("coefficients only supports dense vector but ").append("got type ").append(this.bcCoefficients.value().getClass()).append(".)").toString());
            }
        }
        return this.coefficientsArray;
    }

    private double[] coefficientsArray() {
        return ((byte) (this.bitmap$trans$0 & 1)) == 0 ? coefficientsArray$lzycompute() : this.coefficientsArray;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v0 */
    /* JADX WARN: Type inference failed for: r0v1, types: [java.lang.Throwable] */
    /* JADX WARN: Type inference failed for: r0v10, types: [org.apache.spark.ml.optim.aggregator.BlockLogisticAggregator] */
    private Vector binaryLinear$lzycompute() {
        Vector vector;
        ?? r0 = this;
        synchronized (r0) {
            if (((byte) (this.bitmap$trans$0 & 2)) == 0) {
                Tuple2.mcZZ.sp spVar = new Tuple2.mcZZ.sp(this.multinomial, this.fitIntercept);
                if (spVar != null) {
                    boolean _1$mcZ$sp = spVar._1$mcZ$sp();
                    boolean _2$mcZ$sp = spVar._2$mcZ$sp();
                    if (false == _1$mcZ$sp && true == _2$mcZ$sp) {
                        vector = Vectors$.MODULE$.dense((double[]) new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(coefficientsArray())).take(this.numFeatures));
                        this.binaryLinear = vector;
                        r0 = this;
                        r0.bitmap$trans$0 = (byte) (this.bitmap$trans$0 | 2);
                    }
                }
                if (spVar != null) {
                    boolean _1$mcZ$sp2 = spVar._1$mcZ$sp();
                    boolean _2$mcZ$sp2 = spVar._2$mcZ$sp();
                    if (false == _1$mcZ$sp2 && false == _2$mcZ$sp2) {
                        vector = Vectors$.MODULE$.dense(coefficientsArray());
                        this.binaryLinear = vector;
                        r0 = this;
                        r0.bitmap$trans$0 = (byte) (this.bitmap$trans$0 | 2);
                    }
                }
                vector = null;
                this.binaryLinear = vector;
                r0 = this;
                r0.bitmap$trans$0 = (byte) (this.bitmap$trans$0 | 2);
            }
        }
        return this.binaryLinear;
    }

    private Vector binaryLinear() {
        return ((byte) (this.bitmap$trans$0 & 2)) == 0 ? binaryLinear$lzycompute() : this.binaryLinear;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v0 */
    /* JADX WARN: Type inference failed for: r0v1, types: [java.lang.Throwable] */
    /* JADX WARN: Type inference failed for: r0v10, types: [org.apache.spark.ml.optim.aggregator.BlockLogisticAggregator] */
    private DenseMatrix multinomialLinear$lzycompute() {
        DenseMatrix denseMatrix;
        ?? r0 = this;
        synchronized (r0) {
            if (((byte) (this.bitmap$trans$0 & 4)) == 0) {
                Tuple2.mcZZ.sp spVar = new Tuple2.mcZZ.sp(this.multinomial, this.fitIntercept);
                if (spVar != null) {
                    boolean _1$mcZ$sp = spVar._1$mcZ$sp();
                    boolean _2$mcZ$sp = spVar._2$mcZ$sp();
                    if (true == _1$mcZ$sp && true == _2$mcZ$sp) {
                        denseMatrix = Matrices$.MODULE$.dense(this.numClasses, this.numFeatures, (double[]) new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(coefficientsArray())).take(this.numClasses * this.numFeatures)).toDense();
                        this.multinomialLinear = denseMatrix;
                        r0 = this;
                        r0.bitmap$trans$0 = (byte) (this.bitmap$trans$0 | 4);
                    }
                }
                if (spVar != null) {
                    boolean _1$mcZ$sp2 = spVar._1$mcZ$sp();
                    boolean _2$mcZ$sp2 = spVar._2$mcZ$sp();
                    if (true == _1$mcZ$sp2 && false == _2$mcZ$sp2) {
                        denseMatrix = Matrices$.MODULE$.dense(this.numClasses, this.numFeatures, coefficientsArray()).toDense();
                        this.multinomialLinear = denseMatrix;
                        r0 = this;
                        r0.bitmap$trans$0 = (byte) (this.bitmap$trans$0 | 4);
                    }
                }
                denseMatrix = null;
                this.multinomialLinear = denseMatrix;
                r0 = this;
                r0.bitmap$trans$0 = (byte) (this.bitmap$trans$0 | 4);
            }
        }
        return this.multinomialLinear;
    }

    private DenseMatrix multinomialLinear() {
        return ((byte) (this.bitmap$trans$0 & 4)) == 0 ? multinomialLinear$lzycompute() : this.multinomialLinear;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public BlockLogisticAggregator add(InstanceBlock instanceBlock) {
        Predef$.MODULE$.require(instanceBlock.matrix().isTransposed());
        Predef$.MODULE$.require(this.numFeatures == instanceBlock.numFeatures(), () -> {
            return new StringBuilder(66).append("Dimensions mismatch when adding new ").append("instance. Expecting ").append(this.numFeatures).append(" but got ").append(instanceBlock.numFeatures()).append(".").toString();
        });
        Predef$.MODULE$.require(instanceBlock.weightIter().forall(d -> {
            return d >= ((double) 0);
        }), () -> {
            return new StringBuilder(34).append("instance weights ").append(instanceBlock.weightIter().mkString("[", ",", "]")).append(" has to be >= 0.0").toString();
        });
        if (instanceBlock.weightIter().forall(d2 -> {
            return d2 == ((double) 0);
        })) {
            return this;
        }
        if (this.multinomial) {
            multinomialUpdateInPlace(instanceBlock);
        } else {
            binaryUpdateInPlace(instanceBlock);
        }
        return this;
    }

    private void binaryUpdateInPlace(InstanceBlock instanceBlock) {
        int size = instanceBlock.size();
        DenseVector dense = this.fitIntercept ? Vectors$.MODULE$.dense((double[]) Array$.MODULE$.fill(size, () -> {
            return BoxesRunTime.unboxToDouble(new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(this.coefficientsArray())).last());
        }, ClassTag$.MODULE$.Double())).toDense() : Vectors$.MODULE$.zeros(size).toDense();
        BLAS$.MODULE$.gemv(-1.0d, instanceBlock.matrix(), binaryLinear(), -1.0d, dense);
        double d = 0.0d;
        int i = 0;
        while (true) {
            int i2 = i;
            if (i2 >= size) {
                break;
            }
            double apply$mcDI$sp = instanceBlock.getWeight().apply$mcDI$sp(i2);
            if (apply$mcDI$sp > 0) {
                double label = instanceBlock.getLabel(i2);
                double apply = dense.apply(i2);
                d = label > ((double) 0) ? d + (apply$mcDI$sp * Utils$.MODULE$.log1pExp(apply)) : d + (apply$mcDI$sp * (Utils$.MODULE$.log1pExp(apply) - apply));
                dense.values()[i2] = apply$mcDI$sp * ((1.0d / (1.0d + package$.MODULE$.exp(apply))) - label);
            } else {
                dense.values()[i2] = 0.0d;
            }
            i = i2 + 1;
        }
        lossSum_$eq(lossSum() + d);
        weightSum_$eq(weightSum() + BoxesRunTime.unboxToDouble(instanceBlock.weightIter().sum(Numeric$DoubleIsFractional$.MODULE$)));
        if (new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(dense.values())).forall(d2 -> {
            return d2 == ((double) 0);
        })) {
            return;
        }
        boolean z = false;
        SparseMatrix sparseMatrix = null;
        DenseMatrix matrix = instanceBlock.matrix();
        if (matrix instanceof DenseMatrix) {
            DenseMatrix denseMatrix = matrix;
            BLAS$.MODULE$.nativeBLAS().dgemv("N", denseMatrix.numCols(), denseMatrix.numRows(), 1.0d, denseMatrix.values(), denseMatrix.numCols(), dense.values(), 1, 1.0d, gradientSumArray(), 1);
            BoxedUnit boxedUnit = BoxedUnit.UNIT;
        } else {
            if (matrix instanceof SparseMatrix) {
                z = true;
                sparseMatrix = (SparseMatrix) matrix;
                if (this.fitIntercept) {
                    DenseVector dense2 = Vectors$.MODULE$.zeros(this.numFeatures).toDense();
                    BLAS$.MODULE$.gemv(1.0d, sparseMatrix.transpose(), dense, 0.0d, dense2);
                    BLAS$.MODULE$.getBLAS(this.numFeatures).daxpy(this.numFeatures, 1.0d, dense2.values(), 1, gradientSumArray(), 1);
                    BoxedUnit boxedUnit2 = BoxedUnit.UNIT;
                }
            }
            if (!z || this.fitIntercept) {
                throw new IllegalArgumentException(new StringBuilder(21).append("Unknown matrix type ").append(matrix.getClass()).append(".").toString());
            }
            BLAS$.MODULE$.gemv(1.0d, sparseMatrix.transpose(), dense, 1.0d, new DenseVector(gradientSumArray()));
            BoxedUnit boxedUnit3 = BoxedUnit.UNIT;
        }
        if (this.fitIntercept) {
            gradientSumArray()[this.numFeatures] = gradientSumArray()[this.numFeatures] + BoxesRunTime.unboxToDouble(new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(dense.values())).sum(Numeric$DoubleIsFractional$.MODULE$));
        }
    }

    private void multinomialUpdateInPlace(InstanceBlock instanceBlock) {
        int size = instanceBlock.size();
        DenseMatrix zeros = DenseMatrix$.MODULE$.zeros(size, this.numClasses);
        if (this.fitIntercept) {
            double[] coefficientsArray = coefficientsArray();
            int i = this.numClasses * this.numFeatures;
            int i2 = 0;
            while (true) {
                int i3 = i2;
                if (i3 >= this.numClasses) {
                    break;
                }
                double d = coefficientsArray[i + i3];
                int i4 = 0;
                while (true) {
                    int i5 = i4;
                    if (i5 < size) {
                        zeros.update(i5, i3, d);
                        i4 = i5 + 1;
                    }
                }
                i2 = i3 + 1;
            }
        }
        BLAS$.MODULE$.gemm(1.0d, instanceBlock.matrix(), multinomialLinear().transpose(), 1.0d, zeros);
        double d2 = 0.0d;
        double[] dArr = (double[]) Array$.MODULE$.ofDim(this.numClasses, ClassTag$.MODULE$.Double());
        double[] dArr2 = this.fitIntercept ? (double[]) Array$.MODULE$.ofDim(this.numClasses, ClassTag$.MODULE$.Double()) : null;
        for (int i6 = 0; i6 < size; i6++) {
            double apply$mcDI$sp = instanceBlock.getWeight().apply$mcDI$sp(i6);
            if (apply$mcDI$sp > 0) {
                double label = instanceBlock.getLabel(i6);
                double d3 = Double.NEGATIVE_INFINITY;
                int i7 = 0;
                while (true) {
                    int i8 = i7;
                    if (i8 >= this.numClasses) {
                        break;
                    }
                    dArr[i8] = zeros.apply(i6, i8);
                    d3 = package$.MODULE$.max(d3, dArr[i8]);
                    i7 = i8 + 1;
                }
                double d4 = dArr[(int) label];
                double d5 = 0.0d;
                int i9 = 0;
                while (true) {
                    int i10 = i9;
                    if (i10 >= this.numClasses) {
                        break;
                    }
                    if (d3 > 0) {
                        dArr[i10] = dArr[i10] - d3;
                    }
                    double exp = package$.MODULE$.exp(dArr[i10]);
                    d5 += exp;
                    dArr[i10] = exp;
                    i9 = i10 + 1;
                }
                int i11 = 0;
                while (true) {
                    int i12 = i11;
                    if (i12 >= this.numClasses) {
                        break;
                    }
                    double d6 = apply$mcDI$sp * ((dArr[i12] / d5) - (label == ((double) i12) ? 1.0d : 0.0d));
                    zeros.update(i6, i12, d6);
                    if (this.fitIntercept) {
                        dArr2[i12] = dArr2[i12] + d6;
                    }
                    i11 = i12 + 1;
                }
                d2 = d3 > ((double) 0) ? d2 + (apply$mcDI$sp * ((package$.MODULE$.log(d5) - d4) + d3)) : d2 + (apply$mcDI$sp * (package$.MODULE$.log(d5) - d4));
            } else {
                int i13 = 0;
                while (true) {
                    int i14 = i13;
                    if (i14 < this.numClasses) {
                        zeros.update(i6, i14, 0.0d);
                        i13 = i14 + 1;
                    }
                }
            }
        }
        lossSum_$eq(lossSum() + d2);
        weightSum_$eq(weightSum() + BoxesRunTime.unboxToDouble(instanceBlock.weightIter().sum(Numeric$DoubleIsFractional$.MODULE$)));
        DenseMatrix matrix = instanceBlock.matrix();
        if (matrix instanceof DenseMatrix) {
            BLAS$.MODULE$.nativeBLAS().dgemm("T", "T", this.numClasses, this.numFeatures, size, 1.0d, zeros.values(), size, matrix.values(), this.numFeatures, 1.0d, gradientSumArray(), this.numClasses);
            BoxedUnit boxedUnit = BoxedUnit.UNIT;
        } else {
            if (!(matrix instanceof SparseMatrix)) {
                throw new MatchError(matrix);
            }
            SparseMatrix sparseMatrix = (SparseMatrix) matrix;
            DenseMatrix zeros2 = DenseMatrix$.MODULE$.zeros(this.numFeatures, this.numClasses);
            BLAS$.MODULE$.gemm(1.0d, sparseMatrix.transpose(), zeros, 0.0d, zeros2);
            zeros2.foreachActive((obj, obj2, obj3) -> {
                $anonfun$multinomialUpdateInPlace$4(this, BoxesRunTime.unboxToInt(obj), BoxesRunTime.unboxToInt(obj2), BoxesRunTime.unboxToDouble(obj3));
                return BoxedUnit.UNIT;
            });
            BoxedUnit boxedUnit2 = BoxedUnit.UNIT;
        }
        if (this.fitIntercept) {
            BLAS$.MODULE$.getBLAS(this.numClasses).daxpy(this.numClasses, 1.0d, dArr2, 0, 1, gradientSumArray(), this.numClasses * this.numFeatures, 1);
        }
    }

    public static final /* synthetic */ void $anonfun$multinomialUpdateInPlace$4(BlockLogisticAggregator blockLogisticAggregator, int i, int i2, double d) {
        int i3 = (i * blockLogisticAggregator.numClasses) + i2;
        blockLogisticAggregator.gradientSumArray()[i3] = blockLogisticAggregator.gradientSumArray()[i3] + d;
    }

    public BlockLogisticAggregator(int i, int i2, boolean z, boolean z2, Broadcast<Vector> broadcast) {
        this.numFeatures = i;
        this.numClasses = i2;
        this.fitIntercept = z;
        this.multinomial = z2;
        this.bcCoefficients = broadcast;
        DifferentiableLossAggregator.$init$(this);
        Logging.$init$(this);
        if (z2 && i2 <= 2) {
            logInfo(() -> {
                return new StringBuilder(324).append("Multinomial logistic regression for binary classification yields separate ").append("coefficients for positive and negative classes. When no regularization is applied, the").append("result will be effectively the same as binary logistic regression. When regularization").append("is applied, multinomial loss will produce a result different from binary loss.").toString();
            });
        }
        this.numFeaturesPlusIntercept = z ? i + 1 : i;
        this.coefficientSize = ((Vector) broadcast.value()).size();
        this.dim = coefficientSize();
        if (z2) {
            Predef$.MODULE$.require(i2 == coefficientSize() / numFeaturesPlusIntercept(), () -> {
                return new StringBuilder(46).append("The number of ").append("coefficients should be ").append(this.numClasses * this.numFeaturesPlusIntercept()).append(" but was ").append(this.coefficientSize()).toString();
            });
        } else {
            Predef$.MODULE$.require(coefficientSize() == numFeaturesPlusIntercept(), () -> {
                return new StringBuilder(31).append("Expected ").append(this.numFeaturesPlusIntercept()).append(" ").append("coefficients but got ").append(this.coefficientSize()).toString();
            });
            Predef$.MODULE$.require(i2 == 1 || i2 == 2, () -> {
                return new StringBuilder(68).append("Binary logistic aggregator requires numClasses ").append("in {1, 2} but found ").append(this.numClasses).append(".").toString();
            });
        }
    }
}
