/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.optim.aggregator;

import java.io.Serializable;
import java.util.Arrays;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.internal.Logging;
import org.apache.spark.ml.feature.InstanceBlock;
import org.apache.spark.ml.impl.Utils$;
import org.apache.spark.ml.linalg.BLAS$;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.DenseVector$;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator;
import org.slf4j.Logger;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.Option;
import scala.Predef$;
import scala.collection.mutable.ArrayOps;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;
import scala.runtime.java8.JFunction1;

@ScalaSignature(bytes="\u0006\u000194Qa\u0004\t\u0001)qA\u0001\u0002\u000e\u0001\u0003\u0002\u0003\u0006IA\u000e\u0005\t\u0005\u0002\u0011\t\u0011)A\u0005m!A1\t\u0001B\u0001B\u0003%A\t\u0003\u0005H\u0001\t\u0005\t\u0015!\u0003E\u0011!A\u0005A!A!\u0002\u0013I\u0005\"\u0002)\u0001\t\u0003\t\u0006b\u0002-\u0001\u0005\u0004%I!\u0017\u0005\u0007;\u0002\u0001\u000b\u0011\u0002.\t\u000fy\u0003!\u0019!C)3\"1q\f\u0001Q\u0001\niC\u0001\u0002\u0019\u0001\t\u0006\u0004%I!\u0019\u0005\bM\u0002\u0011\r\u0011\"\u0003h\u0011\u0019A\u0007\u0001)A\u0005\u007f!)\u0011\u000e\u0001C\u0001U\ni\")\u001b8befdunZ5ti&\u001c'\t\\8dW\u0006;wM]3hCR|'O\u0003\u0002\u0012%\u0005Q\u0011mZ4sK\u001e\fGo\u001c:\u000b\u0005M!\u0012!B8qi&l'BA\u000b\u0017\u0003\tiGN\u0003\u0002\u00181\u0005)1\u000f]1sW*\u0011\u0011DG\u0001\u0007CB\f7\r[3\u000b\u0003m\t1a\u001c:h'\u0011\u0001Qd\t\u0018\u0011\u0005y\tS\"A\u0010\u000b\u0003\u0001\nQa]2bY\u0006L!AI\u0010\u0003\r\u0005s\u0017PU3g!\u0011!SeJ\u0017\u000e\u0003AI!A\n\t\u00039\u0011KgMZ3sK:$\u0018.\u00192mK2{7o]!hOJ,w-\u0019;peB\u0011\u0001fK\u0007\u0002S)\u0011!\u0006F\u0001\bM\u0016\fG/\u001e:f\u0013\ta\u0013FA\u0007J]N$\u0018M\\2f\u00052|7m\u001b\t\u0003I\u0001\u0001\"a\f\u001a\u000e\u0003AR!!\r\f\u0002\u0011%tG/\u001a:oC2L!a\r\u0019\u0003\u000f1{wmZ5oO\u0006a!mY%om\u0016\u00148/Z*uI\u000e\u0001\u0001cA\u001c;y5\t\u0001H\u0003\u0002:-\u0005I!M]8bI\u000e\f7\u000f^\u0005\u0003wa\u0012\u0011B\u0011:pC\u0012\u001c\u0017m\u001d;\u0011\u0007yit(\u0003\u0002??\t)\u0011I\u001d:bsB\u0011a\u0004Q\u0005\u0003\u0003~\u0011a\u0001R8vE2,\u0017\u0001\u00042d'\u000e\fG.\u001a3NK\u0006t\u0017\u0001\u00044ji&sG/\u001a:dKB$\bC\u0001\u0010F\u0013\t1uDA\u0004C_>dW-\u00198\u0002\u0017\u0019LGoV5uQ6+\u0017M\\\u0001\u000fE\u000e\u001cu.\u001a4gS\u000eLWM\u001c;t!\r9$H\u0013\t\u0003\u0017:k\u0011\u0001\u0014\u0006\u0003\u001bR\ta\u0001\\5oC2<\u0017BA(M\u0005\u00191Vm\u0019;pe\u00061A(\u001b8jiz\"RA\u0015+V-^#\"!L*\t\u000b!3\u0001\u0019A%\t\u000bQ2\u0001\u0019\u0001\u001c\t\u000b\t3\u0
001\u0019\u0001\u001c\t\u000b\r3\u0001\u0019\u0001#\t\u000b\u001d3\u0001\u0019\u0001#\u0002\u00179,XNR3biV\u0014Xm]\u000b\u00025B\u0011adW\u0005\u00039~\u00111!\u00138u\u00031qW/\u001c$fCR,(/Z:!\u0003\r!\u0017.\\\u0001\u0005I&l\u0007%A\td_\u00164g-[2jK:$8/\u0011:sCf,\u0012\u0001\u0010\u0015\u0003\u0017\r\u0004\"A\b3\n\u0005\u0015|\"!\u0003;sC:\u001c\u0018.\u001a8u\u00031i\u0017M]4j]>3gm]3u+\u0005y\u0014!D7be\u001eLgn\u00144gg\u0016$\b%A\u0002bI\u0012$\"a\u001b7\u000e\u0003\u0001AQ!\u001c\bA\u0002\u001d\nQA\u00197pG.\u0004")
/**
 * Aggregates weighted binary logistic loss, total instance weight and the loss gradient over
 * {@link InstanceBlock}s, for one fixed coefficient vector taken from a broadcast.
 *
 * <p>NOTE(review): this is CFR-decompiled output of a Scala class. The {@code bitmap$...}
 * flags and {@code ...$lzycompute} methods implement Scala {@code lazy val}s. The
 * {@code org$apache$spark$internal$Logging$$log_} accessors back the {@code Logging} trait.
 * The static {@code Trait.method$(this, ...)} calls forward to trait default bodies. Names such
 * as {@code JFunction1.mcZD.sp} are the decompiler's rendering of Scala's specialized function
 * classes (presumably {@code JFunction1$mcZD$sp}).
 *
 * <p>Coefficient layout as used by {@link #add}:
 * <ul>
 *   <li>entries {@code [0, numFeatures)} are feature weights;</li>
 *   <li>when {@code fitIntercept} is set, the intercept is read from the LAST coefficient;</li>
 *   <li>the intercept's gradient is written at index {@code numFeatures}.</li>
 * </ul>
 * The last two positions coincide only when {@code dim == numFeatures + 1}. That is assumed
 * here and not checked.
 */
public class BinaryLogisticBlockAggregator
implements DifferentiableLossAggregator<InstanceBlock, BinaryLogisticBlockAggregator>,
Logging {
    // Dense coefficient values, lazily extracted from bcCoefficients. Transient together with
    // its init flag, so if this object is serialized the copy re-derives it from the broadcast.
    private transient double[] coefficientsArray;
    // Per-feature scaled means; read only when fitWithMean is set.
    private final Broadcast<double[]> bcScaledMean;
    // Whether an intercept is included: margins are offset by it and its gradient is accumulated.
    private final boolean fitIntercept;
    // Whether centering is folded into the margin offset and the gradient (requires fitIntercept).
    private final boolean fitWithMean;
    // Broadcast coefficients; must hold a DenseVector (see coefficientsArray$lzycompute).
    private final Broadcast<Vector> bcCoefficients;
    // Number of features: the length of the bcInverseStd array given to the constructor.
    private final int numFeatures;
    // Size of the coefficient vector.
    private final int dim;
    // With fitWithMean: last coefficient minus dot(coefficients[0, numFeatures), scaledMean).
    // Otherwise NaN and never read.
    private final double marginOffset;
    // Backing field for the Logging trait's lazily created logger.
    private transient Logger org$apache$spark$internal$Logging$$log_;
    // Running sums required by DifferentiableLossAggregator (exposed through the accessors below).
    private double weightSum;
    private double lossSum;
    private double[] gradientSumArray;
    // Scala lazy-val init flags: bitmap$trans$0 guards coefficientsArray, bitmap$0 guards
    // gradientSumArray.
    private volatile transient boolean bitmap$trans$0;
    private volatile boolean bitmap$0;

    // ---- Logging trait forwarders (compiler-generated; each delegates to the trait body) ----

    public String logName() {
        return Logging.logName$((Logging)this);
    }

    public Logger log() {
        return Logging.log$((Logging)this);
    }

    public void logInfo(Function0<String> msg) {
        Logging.logInfo$((Logging)this, msg);
    }

    public void logDebug(Function0<String> msg) {
        Logging.logDebug$((Logging)this, msg);
    }

    public void logTrace(Function0<String> msg) {
        Logging.logTrace$((Logging)this, msg);
    }

    public void logWarning(Function0<String> msg) {
        Logging.logWarning$((Logging)this, msg);
    }

    public void logError(Function0<String> msg) {
        Logging.logError$((Logging)this, msg);
    }

    public void logInfo(Function0<String> msg, Throwable throwable) {
        Logging.logInfo$((Logging)this, msg, (Throwable)throwable);
    }

    public void logDebug(Function0<String> msg, Throwable throwable) {
        Logging.logDebug$((Logging)this, msg, (Throwable)throwable);
    }

    public void logTrace(Function0<String> msg, Throwable throwable) {
        Logging.logTrace$((Logging)this, msg, (Throwable)throwable);
    }

    public void logWarning(Function0<String> msg, Throwable throwable) {
        Logging.logWarning$((Logging)this, msg, (Throwable)throwable);
    }

    public void logError(Function0<String> msg, Throwable throwable) {
        Logging.logError$((Logging)this, msg, (Throwable)throwable);
    }

    public boolean isTraceEnabled() {
        return Logging.isTraceEnabled$((Logging)this);
    }

    public void initializeLogIfNecessary(boolean isInterpreter) {
        Logging.initializeLogIfNecessary$((Logging)this, (boolean)isInterpreter);
    }

    public boolean initializeLogIfNecessary(boolean isInterpreter, boolean silent) {
        return Logging.initializeLogIfNecessary$((Logging)this, (boolean)isInterpreter, (boolean)silent);
    }

    public boolean initializeLogIfNecessary$default$2() {
        return Logging.initializeLogIfNecessary$default$2$((Logging)this);
    }

    public void initializeForcefully(boolean isInterpreter, boolean silent) {
        Logging.initializeForcefully$((Logging)this, (boolean)isInterpreter, (boolean)silent);
    }

    // ---- DifferentiableLossAggregator forwarders (delegate to the trait's default bodies) ----

    @Override
    public DifferentiableLossAggregator merge(DifferentiableLossAggregator other) {
        return DifferentiableLossAggregator.merge$(this, other);
    }

    @Override
    public Vector gradient() {
        return DifferentiableLossAggregator.gradient$(this);
    }

    @Override
    public double weight() {
        return DifferentiableLossAggregator.weight$(this);
    }

    @Override
    public double loss() {
        return DifferentiableLossAggregator.loss$(this);
    }

    public Logger org$apache$spark$internal$Logging$$log_() {
        return this.org$apache$spark$internal$Logging$$log_;
    }

    public void org$apache$spark$internal$Logging$$log__$eq(Logger x$1) {
        this.org$apache$spark$internal$Logging$$log_ = x$1;
    }

    // ---- Accessors for trait-declared mutable state ----

    @Override
    public double weightSum() {
        return this.weightSum;
    }

    @Override
    public void weightSum_$eq(double x$1) {
        this.weightSum = x$1;
    }

    @Override
    public double lossSum() {
        return this.lossSum;
    }

    @Override
    public void lossSum_$eq(double x$1) {
        this.lossSum = x$1;
    }

    /**
     * Slow path of {@link #gradientSumArray()}: initializes the array once, under
     * {@code synchronized (this)}. The initial value comes from the trait's default
     * {@code gradientSumArray} body (presumably a zero array of length {@code dim} — confirm).
     */
    private double[] gradientSumArray$lzycompute() {
        BinaryLogisticBlockAggregator binaryLogisticBlockAggregator = this;
        synchronized (binaryLogisticBlockAggregator) {
            if (!this.bitmap$0) {
                this.gradientSumArray = DifferentiableLossAggregator.gradientSumArray$(this);
                this.bitmap$0 = true;
            }
        }
        return this.gradientSumArray;
    }

    /** Returns the mutable gradient accumulator, initializing it on first access. */
    @Override
    public double[] gradientSumArray() {
        if (!this.bitmap$0) {
            return this.gradientSumArray$lzycompute();
        }
        return this.gradientSumArray;
    }

    private int numFeatures() {
        return this.numFeatures;
    }

    @Override
    public int dim() {
        return this.dim;
    }

    /**
     * Slow path of {@link #coefficientsArray()}: extracts the values array from the broadcast
     * coefficients once, under {@code synchronized (this)}.
     *
     * @throws IllegalArgumentException if the broadcast coefficients are not a {@code DenseVector}
     */
    private double[] coefficientsArray$lzycompute() {
        BinaryLogisticBlockAggregator binaryLogisticBlockAggregator = this;
        synchronized (binaryLogisticBlockAggregator) {
            if (!this.bitmap$trans$0) {
                DenseVector denseVector;
                Option option;
                Vector vector = (Vector)this.bcCoefficients.value();
                if (!(vector instanceof DenseVector) || (option = DenseVector$.MODULE$.unapply(denseVector = (DenseVector)vector)).isEmpty()) {
                    // Fixed: the message previously ended with a stray ".)".
                    throw new IllegalArgumentException(new StringBuilder(54).append("coefficients only supports dense vector but ").append("got type ").append(this.bcCoefficients.value().getClass()).append(".").toString());
                }
                double[] values = (double[])option.get();
                this.coefficientsArray = values;
                this.bitmap$trans$0 = true;
            }
        }
        return this.coefficientsArray;
    }

    /** Returns the dense coefficient values, extracting them from the broadcast on first use. */
    private double[] coefficientsArray() {
        if (!this.bitmap$trans$0) {
            return this.coefficientsArray$lzycompute();
        }
        return this.coefficientsArray;
    }

    private double marginOffset() {
        return this.marginOffset;
    }

    /**
     * Folds one block of instances into the running loss, weight and gradient sums.
     *
     * <p>Steps:
     * <ol>
     *   <li>Validate the block: its matrix must be transposed, its feature count must match,
     *       and no weight may be negative.</li>
     *   <li>Return unchanged if every weight is zero.</li>
     *   <li>Compute margins {@code arr = offset + matrix * coefficients}.</li>
     *   <li>For each positive-weight instance, add its weighted logistic loss. Then overwrite its
     *       margin in place with the multiplier {@code weight * (sigmoid(margin) - label)}.
     *       Zero-weight instances get multiplier 0.</li>
     *   <li>Add {@code matrix^T * arr} into the gradient. Apply the centering correction and the
     *       intercept component where enabled.</li>
     * </ol>
     *
     * <p>Loss and weight sums are updated even when every multiplier is zero. In that case the
     * gradient update is skipped.
     *
     * @param block instances to add; its matrix must report {@code isTransposed()}
     * @return this aggregator
     * @throws IllegalArgumentException if a requirement in step 1 fails (via {@code Predef.require})
     */
    @Override
    public BinaryLogisticBlockAggregator add(InstanceBlock block) {
        // "block6:" is a decompiler-generated label; "break block6" just skips to "return this".
        block6: {
            Predef$.MODULE$.require(block.matrix().isTransposed());
            Predef$.MODULE$.require(this.numFeatures() == block.numFeatures(), (Function0 & Serializable & scala.Serializable)() -> new StringBuilder(66).append("Dimensions mismatch when adding new ").append("instance. Expecting ").append(this.numFeatures()).append(" but got ").append(block.numFeatures()).append(".").toString());
            Predef$.MODULE$.require(block.weightIter().forall((Function1)(JFunction1.mcZD.sp & Serializable & scala.Serializable)x$1 -> x$1 >= 0.0), (Function0 & Serializable & scala.Serializable)() -> new StringBuilder(34).append("instance weights ").append(block.weightIter().mkString("[", ",", "]")).append(" has to be >= 0.0").toString());
            if (block.weightIter().forall((Function1)(JFunction1.mcZD.sp & Serializable & scala.Serializable)x$2 -> x$2 == 0.0)) {
                return this;
            }
            int size = block.size();
            // arr holds margins first, then is overwritten in place with gradient multipliers.
            double[] arr = (double[])Array$.MODULE$.ofDim(size, ClassTag$.MODULE$.Double());
            if (this.fitIntercept) {
                // With centering the precomputed marginOffset replaces the raw intercept (last coefficient).
                double offset = this.fitWithMean ? this.marginOffset() : BoxesRunTime.unboxToDouble((Object)new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(this.coefficientsArray())).last());
                Arrays.fill(arr, offset);
            }
            // arr = 1.0 * matrix * coefficients + 1.0 * arr  (beta = 1.0 keeps the offset)
            BLAS$.MODULE$.gemv(1.0, block.matrix(), this.coefficientsArray(), 1.0, arr);
            double localLossSum = 0.0;
            double localWeightSum = 0.0;
            double multiplierSum = 0.0;
            for (int i = 0; i < size; ++i) {
                double weight = block.getWeight().apply$mcDI$sp(i);
                localWeightSum += weight;
                if (weight > 0.0) {
                    double multiplier;
                    double label = block.getLabel(i);
                    double margin = arr[i];
                    // Decompiler artifact: equivalent to an if/else with "localLossSum += ...".
                    // Positive label: weight * log1pExp(-margin). Otherwise: weight * (log1pExp(-margin) + margin).
                    localLossSum = label > 0.0 ? (localLossSum += weight * Utils$.MODULE$.log1pExp(-margin)) : (localLossSum += weight * (Utils$.MODULE$.log1pExp(-margin) + margin));
                    arr[i] = multiplier = weight * (1.0 / (1.0 + package$.MODULE$.exp(-margin)) - label);
                    multiplierSum += multiplier;
                    continue;
                }
                arr[i] = 0.0;
            }
            this.lossSum_$eq(this.lossSum() + localLossSum);
            this.weightSum_$eq(this.weightSum() + localWeightSum);
            // No gradient signal from this block: skip the matrix-vector product.
            if (new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(arr)).forall((Function1)(JFunction1.mcZD.sp & Serializable & scala.Serializable)x$3 -> x$3 == 0.0)) {
                return this;
            }
            // gradientSum += matrix^T * multipliers (linear part, ignoring centering)
            BLAS$.MODULE$.gemv(1.0, block.matrix().transpose(), arr, 1.0, this.gradientSumArray());
            if (this.fitWithMean) {
                // Centering correction: gradientSum[0, numFeatures) -= multiplierSum * scaledMean
                BLAS$.MODULE$.javaBLAS().daxpy(this.numFeatures(), -multiplierSum, (double[])this.bcScaledMean.value(), 1, this.gradientSumArray(), 1);
            }
            if (!this.fitIntercept) break block6;
            // Intercept gradient is the sum of multipliers, stored just after the feature entries.
            this.gradientSumArray()[this.numFeatures()] = this.gradientSumArray()[this.numFeatures()] + multiplierSum;
        }
        return this;
    }

    /**
     * Creates an aggregator for the coefficients held in {@code bcCoefficients}.
     *
     * @param bcInverseStd broadcast array whose length defines {@code numFeatures}; only its
     *     length is read in this class
     * @param bcScaledMean broadcast per-feature scaled means. When {@code fitWithMean} is set it
     *     must be non-null and have the same length as {@code bcInverseStd}.
     * @param fitIntercept whether an intercept is included in margins and the gradient
     * @param fitWithMean whether centering by {@code bcScaledMean} is folded into the margin
     *     offset and gradient; requires {@code fitIntercept}
     * @param bcCoefficients broadcast coefficient vector; must hold a {@code DenseVector} by the
     *     time its values are first read (here in the constructor when {@code fitWithMean} is
     *     set, otherwise in {@link #add})
     * @throws IllegalArgumentException if {@code fitWithMean} is set without {@code fitIntercept},
     *     or the scaled means are missing or mis-sized
     */
    public BinaryLogisticBlockAggregator(Broadcast<double[]> bcInverseStd, Broadcast<double[]> bcScaledMean, boolean fitIntercept, boolean fitWithMean, Broadcast<Vector> bcCoefficients) {
        this.bcScaledMean = bcScaledMean;
        this.fitIntercept = fitIntercept;
        this.fitWithMean = fitWithMean;
        this.bcCoefficients = bcCoefficients;
        // Trait initializers (presumably reset weightSum/lossSum — confirm in the trait).
        DifferentiableLossAggregator.$init$(this);
        Logging.$init$((Logging)this);
        if (fitWithMean) {
            Predef$.MODULE$.require(fitIntercept, (Function0 & Serializable & scala.Serializable)() -> "for training without intercept, should not center the vectors");
            Predef$.MODULE$.require(bcScaledMean != null && ((double[])bcScaledMean.value()).length == ((double[])bcInverseStd.value()).length, (Function0 & Serializable & scala.Serializable)() -> "scaled means is required when center the vectors");
        }
        this.numFeatures = ((double[])bcInverseStd.value()).length;
        this.dim = ((Vector)bcCoefficients.value()).size();
        // numFeatures must be assigned before this line: the ddot below reads it.
        this.marginOffset = fitWithMean ? BoxesRunTime.unboxToDouble((Object)new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(this.coefficientsArray())).last()) - BLAS$.MODULE$.javaBLAS().ddot(this.numFeatures(), this.coefficientsArray(), 1, (double[])bcScaledMean.value(), 1) : Double.NaN;
    }
}

