package org.apache.spark.ml.optim.aggregator;

import java.util.Arrays;

import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.feature.InstanceBlock;
import org.apache.spark.ml.linalg.BLAS$;
import org.apache.spark.ml.linalg.DenseMatrix;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.DenseVector$;
import org.apache.spark.ml.linalg.Matrix;
import org.apache.spark.ml.linalg.SparseMatrix;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors$;
import scala.Array$;
import scala.MatchError;
import scala.Option;
import scala.Predef$;
import scala.collection.mutable.ArrayOps;
import scala.math.Numeric$DoubleIsFractional$;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: AFTAggregator.scala */
@ScalaSignature(bytes = "\u0006\u0001u3Qa\u0003\u0007\u0001!aA\u0001B\u000b\u0001\u0003\u0002\u0003\u0006I\u0001\f\u0005\t_\u0001\u0011\t\u0011)A\u0005a!)A\b\u0001C\u0001{!9\u0011\t\u0001b\u0001\n#\u0012\u0005B\u0002$\u0001A\u0003%1\tC\u0004H\u0001\t\u0007I\u0011\u0002\"\t\r!\u0003\u0001\u0015!\u0003D\u0011!I\u0005\u0001#b\u0001\n\u0013Q\u0005\u0002C+\u0001\u0011\u000b\u0007I\u0011\u0002,\t\u000ba\u0003A\u0011A-\u0003%\tcwnY6B\rR\u000bum\u001a:fO\u0006$xN\u001d\u0006\u0003\u001b9\t!\"Y4he\u0016<\u0017\r^8s\u0015\ty\u0001#A\u0003paRLWN\u0003\u0002\u0012%\u0005\u0011Q\u000e\u001c\u0006\u0003'Q\tQa\u001d9be.T!!\u0006\f\u0002\r\u0005\u0004\u0018m\u00195f\u0015\u00059\u0012aA8sON\u0019\u0001!G\u0010\u0011\u0005iiR\"A\u000e\u000b\u0003q\tQa]2bY\u0006L!AH\u000e\u0003\r\u0005s\u0017PU3g!\u0011\u0001\u0013eI\u0015\u000e\u00031I!A\t\u0007\u00039\u0011KgMZ3sK:$\u0018.\u00192mK2{7o]!hOJ,w-\u0019;peB\u0011AeJ\u0007\u0002K)\u0011a\u0005E\u0001\bM\u0016\fG/\u001e:f\u0013\tASEA\u0007J]N$\u0018M\\2f\u00052|7m\u001b\t\u0003A\u0001\tABZ5u\u0013:$XM]2faR\u001c\u0001\u0001\u0005\u0002\u001b[%\u0011af\u0007\u0002\b\u0005>|G.Z1o\u00039\u00117mQ8fM\u001aL7-[3oiN\u00042!\r\u001b7\u001b\u0005\u0011$BA\u001a\u0013\u0003%\u0011'o\\1eG\u0006\u001cH/\u0003\u00026e\tI!I]8bI\u000e\f7\u000f\u001e\t\u0003oij\u0011\u0001\u000f\u0006\u0003sA\ta\u0001\\5oC2<\u0017BA\u001e9\u0005\u00191Vm\u0019;pe\u00061A(\u001b8jiz\"\"A\u0010!\u0015\u0005%z\u0004\"B\u0018\u0004\u0001\u0004\u0001\u0004\"\u0002\u0016\u0004\u0001\u0004a\u0013a\u00013j[V\t1\t\u0005\u0002\u001b\t&\u0011Qi\u0007\u0002\u0004\u0013:$\u0018\u0001\u00023j[\u0002\n1B\\;n\r\u0016\fG/\u001e:fg\u0006aa.^7GK\u0006$XO]3tA\u0005\t2m\\3gM&\u001c\u0017.\u001a8ug\u0006\u0013(/Y=\u0016\u0003-\u00032A\u0007'O\u0013\ti5DA\u0003BeJ\f\u0017\u0010\u0005\u0002\u001b\u001f&\u0011\u0001k\u0007\u0002\u0007\t>,(\r\\3)\u0005!\u0011\u0006C\u0001\u000eT\u0013\t!6DA\u0005ue\u0006t7/[3oi\u00061A.\u001b8fCJ,\u0012A\u000e\u0015\u0003\u0013I\u000b1!\u00193e)\tQ6,D\u0001\u00
01\u0011\u0015a&\u00021\u0001$\u0003\u0015\u0011Gn\\2l\u0001")
/* loaded from: input_file:org/apache/spark/ml/optim/aggregator/BlockAFTAggregator.class */
public class BlockAFTAggregator implements DifferentiableLossAggregator<InstanceBlock, BlockAFTAggregator> {
    private transient double[] coefficientsArray;
    private transient Vector linear;
    private final boolean fitIntercept;
    private final Broadcast<Vector> bcCoefficients;
    private final int dim;
    private final int numFeatures;
    private double weightSum;
    private double lossSum;
    private double[] gradientSumArray;
    private volatile transient byte bitmap$trans$0;
    private volatile boolean bitmap$0;

    /* JADX WARN: Type inference failed for: r0v1, types: [org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator, org.apache.spark.ml.optim.aggregator.BlockAFTAggregator] */
    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public BlockAFTAggregator merge(BlockAFTAggregator blockAFTAggregator) {
        ?? merge;
        merge = merge(blockAFTAggregator);
        return merge;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public Vector gradient() {
        Vector gradient;
        gradient = gradient();
        return gradient;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public double weight() {
        double weight;
        weight = weight();
        return weight;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public double loss() {
        double loss;
        loss = loss();
        return loss;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public double weightSum() {
        return this.weightSum;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public void weightSum_$eq(double d) {
        this.weightSum = d;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public double lossSum() {
        return this.lossSum;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public void lossSum_$eq(double d) {
        this.lossSum = d;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v0 */
    /* JADX WARN: Type inference failed for: r0v1, types: [java.lang.Throwable] */
    /* JADX WARN: Type inference failed for: r0v8, types: [org.apache.spark.ml.optim.aggregator.BlockAFTAggregator] */
    private double[] gradientSumArray$lzycompute() {
        double[] gradientSumArray;
        ?? r0 = this;
        synchronized (r0) {
            if (!this.bitmap$0) {
                gradientSumArray = gradientSumArray();
                this.gradientSumArray = gradientSumArray;
                r0 = this;
                r0.bitmap$0 = true;
            }
        }
        return this.gradientSumArray;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public double[] gradientSumArray() {
        return !this.bitmap$0 ? gradientSumArray$lzycompute() : this.gradientSumArray;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public int dim() {
        return this.dim;
    }

    private int numFeatures() {
        return this.numFeatures;
    }

    /* JADX WARN: Multi-variable type inference failed */
    private double[] coefficientsArray$lzycompute() {
        synchronized (this) {
            if (((byte) (this.bitmap$trans$0 & 1)) == 0) {
                DenseVector denseVector = (Vector) this.bcCoefficients.value();
                if (denseVector instanceof DenseVector) {
                    Option unapply = DenseVector$.MODULE$.unapply(denseVector);
                    if (!unapply.isEmpty()) {
                        this.coefficientsArray = (double[]) unapply.get();
                        this.bitmap$trans$0 = (byte) (this.bitmap$trans$0 | 1);
                    }
                }
                throw new IllegalArgumentException(new StringBuilder(54).append("coefficients only supports dense vector").append(" but got type ").append(this.bcCoefficients.value().getClass()).append(".").toString());
            }
        }
        return this.coefficientsArray;
    }

    private double[] coefficientsArray() {
        return ((byte) (this.bitmap$trans$0 & 1)) == 0 ? coefficientsArray$lzycompute() : this.coefficientsArray;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v0 */
    /* JADX WARN: Type inference failed for: r0v1, types: [java.lang.Throwable] */
    /* JADX WARN: Type inference failed for: r0v10, types: [org.apache.spark.ml.optim.aggregator.BlockAFTAggregator] */
    private Vector linear$lzycompute() {
        ?? r0 = this;
        synchronized (r0) {
            if (((byte) (this.bitmap$trans$0 & 2)) == 0) {
                this.linear = Vectors$.MODULE$.dense((double[]) new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(coefficientsArray())).take(numFeatures()));
                r0 = this;
                r0.bitmap$trans$0 = (byte) (this.bitmap$trans$0 | 2);
            }
        }
        return this.linear;
    }

    private Vector linear() {
        return ((byte) (this.bitmap$trans$0 & 2)) == 0 ? linear$lzycompute() : this.linear;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public BlockAFTAggregator add(InstanceBlock instanceBlock) {
        Predef$.MODULE$.require(instanceBlock.matrix().isTransposed());
        Predef$.MODULE$.require(numFeatures() == instanceBlock.numFeatures(), () -> {
            return new StringBuilder(66).append("Dimensions mismatch when adding new ").append("instance. Expecting ").append(this.numFeatures()).append(" but got ").append(instanceBlock.numFeatures()).append(".").toString();
        });
        Predef$.MODULE$.require(new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(instanceBlock.labels())).forall(d -> {
            return d > 0.0d;
        }), () -> {
            return "The lifetime or label should be  greater than 0.";
        });
        int size = instanceBlock.size();
        double d2 = coefficientsArray()[dim() - 2];
        double exp = package$.MODULE$.exp(coefficientsArray()[dim() - 1]);
        DenseVector dense = this.fitIntercept ? Vectors$.MODULE$.dense((double[]) Array$.MODULE$.fill(size, () -> {
            return d2;
        }, ClassTag$.MODULE$.Double())).toDense() : Vectors$.MODULE$.zeros(size).toDense();
        BLAS$.MODULE$.gemv(1.0d, instanceBlock.matrix(), linear(), 1.0d, dense);
        double d3 = 0.0d;
        double d4 = 0.0d;
        for (int i = 0; i < size; i++) {
            double label = instanceBlock.getLabel(i);
            double apply$mcDI$sp = instanceBlock.getWeight().apply$mcDI$sp(i);
            double log = (package$.MODULE$.log(label) - dense.apply(i)) / exp;
            double exp2 = package$.MODULE$.exp(log);
            d3 += ((apply$mcDI$sp * package$.MODULE$.log(exp)) - (apply$mcDI$sp * log)) + exp2;
            double d5 = (apply$mcDI$sp - exp2) / exp;
            dense.values()[i] = d5;
            d4 += apply$mcDI$sp + (d5 * exp * log);
        }
        lossSum_$eq(lossSum() + d3);
        weightSum_$eq(weightSum() + size);
        DenseMatrix matrix = instanceBlock.matrix();
        if (matrix instanceof DenseMatrix) {
            DenseMatrix denseMatrix = matrix;
            BLAS$.MODULE$.nativeBLAS().dgemv("N", denseMatrix.numCols(), denseMatrix.numRows(), 1.0d, denseMatrix.values(), denseMatrix.numCols(), dense.values(), 1, 1.0d, gradientSumArray(), 1);
            BoxedUnit boxedUnit = BoxedUnit.UNIT;
        } else {
            if (!(matrix instanceof SparseMatrix)) {
                throw new MatchError(matrix);
            }
            SparseMatrix sparseMatrix = (SparseMatrix) matrix;
            DenseVector dense2 = Vectors$.MODULE$.zeros(numFeatures()).toDense();
            BLAS$.MODULE$.gemv(1.0d, sparseMatrix.transpose(), dense, 0.0d, dense2);
            BLAS$.MODULE$.getBLAS(numFeatures()).daxpy(numFeatures(), 1.0d, dense2.values(), 1, gradientSumArray(), 1);
            BoxedUnit boxedUnit2 = BoxedUnit.UNIT;
        }
        if (this.fitIntercept) {
            int dim = dim() - 2;
            gradientSumArray()[dim] = gradientSumArray()[dim] + BoxesRunTime.unboxToDouble(new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(dense.values())).sum(Numeric$DoubleIsFractional$.MODULE$));
        }
        int dim2 = dim() - 1;
        gradientSumArray()[dim2] = gradientSumArray()[dim2] + d4;
        return this;
    }

    public BlockAFTAggregator(boolean z, Broadcast<Vector> broadcast) {
        this.fitIntercept = z;
        this.bcCoefficients = broadcast;
        DifferentiableLossAggregator.$init$(this);
        this.dim = ((Vector) broadcast.value()).size();
        this.numFeatures = dim() - 2;
    }
}
