package org.apache.spark.mllib.optimization;

import org.apache.spark.mllib.linalg.BLAS$;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.util.MLUtils$;
import scala.Array$;
import scala.Predef$;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.DoubleRef;
import scala.runtime.IntRef;
import scala.runtime.RichInt$;

/* compiled from: Gradient.scala */
@ScalaSignature(bytes = "\u0006\u0001U2A!\u0002\u0004\u0001#!Aa\u0003\u0001B\u0001B\u0003%q\u0003C\u0003\u001e\u0001\u0011\u0005a\u0004C\u0003\u001e\u0001\u0011\u0005\u0011\u0005C\u0003#\u0001\u0011\u00053E\u0001\tM_\u001eL7\u000f^5d\u000fJ\fG-[3oi*\u0011q\u0001C\u0001\r_B$\u0018.\\5{CRLwN\u001c\u0006\u0003\u0013)\tQ!\u001c7mS\nT!a\u0003\u0007\u0002\u000bM\u0004\u0018M]6\u000b\u00055q\u0011AB1qC\u000eDWMC\u0001\u0010\u0003\ry'oZ\u0002\u0001'\t\u0001!\u0003\u0005\u0002\u0014)5\ta!\u0003\u0002\u0016\r\tAqI]1eS\u0016tG/\u0001\u0006ok6\u001cE.Y:tKN\u0004\"\u0001G\u000e\u000e\u0003eQ\u0011AG\u0001\u0006g\u000e\fG.Y\u0005\u00039e\u00111!\u00138u\u0003\u0019a\u0014N\\5u}Q\u0011q\u0004\t\t\u0003'\u0001AQA\u0006\u0002A\u0002]!\u0012aH\u0001\bG>l\u0007/\u001e;f)\u0015!seL\u00194!\tAR%\u0003\u0002'3\t1Ai\\;cY\u0016DQ\u0001\u000b\u0003A\u0002%\nA\u0001Z1uCB\u0011!&L\u0007\u0002W)\u0011A\u0006C\u0001\u0007Y&t\u0017\r\\4\n\u00059Z#A\u0002,fGR|'\u000fC\u00031\t\u0001\u0007A%A\u0003mC\n,G\u000eC\u00033\t\u0001\u0007\u0011&A\u0004xK&<\u0007\u000e^:\t\u000bQ\"\u0001\u0019A\u0015\u0002\u0017\r,Xn\u0012:bI&,g\u000e\u001e")
/* loaded from: input_file:org/apache/spark/mllib/optimization/LogisticGradient.class */
/**
 * Gradient of the logistic loss for a single example, for either two classes
 * (specialised branch) or more than two classes (general branch).
 *
 * <p>The class count is fixed at construction time; {@link #compute} checks that the
 * length of the weight vector is consistent with it.
 */
public class LogisticGradient extends Gradient {
    /** Number of classes; the value 2 selects the specialised two-class branch. */
    private final int numClasses;

    /**
     * Creates a gradient for the given number of classes.
     *
     * @param i number of classes
     */
    public LogisticGradient(int i) {
        this.numClasses = i;
    }

    /** Creates a two-class gradient. */
    public LogisticGradient() {
        this(2);
    }

    /**
     * Accumulates the gradient for one example into {@code vector3} and returns the loss.
     *
     * <p>The precondition {@code weights.size % data.size == 0 &&
     * numClasses == weights.size / data.size + 1} is checked with Scala's
     * {@code Predef.require} before any work is done.
     *
     * @param vector  feature vector of the example
     * @param d       label of the example; in the multi-class branch a value {@code k > 0}
     *                selects margin row {@code k - 1} and {@code 0} selects no row
     * @param vector2 weights; must be a {@link DenseVector} in the multi-class branch
     * @param vector3 cumulative gradient, updated in place; must be a {@link DenseVector}
     *                in the multi-class branch
     * @return the loss contribution of this example
     * @throws IllegalArgumentException in the multi-class branch, if {@code vector2} or
     *                                  {@code vector3} is not a {@link DenseVector}
     */
    @Override // org.apache.spark.mllib.optimization.Gradient
    public double compute(Vector vector, double d, Vector vector2, Vector vector3) {
        final int dataSize = vector.size();
        Predef$.MODULE$.require(vector2.size() % dataSize == 0 && this.numClasses == (vector2.size() / dataSize) + 1);
        if (this.numClasses == 2) {
            return computeBinary(vector, d, vector2, vector3);
        }
        return computeMultinomial(vector, d, vector2, vector3, dataSize);
    }

    /** Two-class case: one dot product, one axpy into the cumulative gradient, and the loss. */
    private static double computeBinary(Vector data, double label, Vector weights, Vector cumGradient) {
        final double margin = (-1.0d) * BLAS$.MODULE$.dot(data, weights);
        final double multiplier = (1.0d / (1.0d + package$.MODULE$.exp(margin))) - label;
        BLAS$.MODULE$.axpy(multiplier, data, cumGradient);
        final double loss = MLUtils$.MODULE$.log1pExp(margin);
        return label > 0.0d ? loss : loss - margin;
    }

    /** Multi-class case: per-class margins, shifted exponentials, gradient rows, and the loss. */
    private double computeMultinomial(Vector data, double label, Vector weights, Vector cumGradient, int dataSize) {
        // The weights check comes first so that the exception raised for doubly-invalid input is unchanged.
        if (!(weights instanceof DenseVector)) {
            throw new IllegalArgumentException(new StringBuilder(49).append("weights only supports dense vector but got type ").append(weights.getClass()).append(".").toString());
        }
        final double[] weightsArray = ((DenseVector) weights).values();
        if (!(cumGradient instanceof DenseVector)) {
            throw new IllegalArgumentException(new StringBuilder(53).append("cumGradient only supports dense vector but got type ").append(cumGradient.getClass()).append(".").toString());
        }
        final double[] cumGradientArray = ((DenseVector) cumGradient).values();

        final int numMargins = this.numClasses - 1;
        final int labelRow = ((int) label) - 1;

        // Pass 1: margin of every row, remembering the labelled row's margin and the largest one.
        final double[] margins = new double[numMargins];
        double marginY = 0.0d;
        double maxMargin = Double.NEGATIVE_INFINITY;
        int maxMarginIndex = 0;
        for (int row = 0; row < numMargins; row++) {
            // Row `row` of the weights occupies [row * dataSize, (row + 1) * dataSize).
            final int rowOffset = row * dataSize;
            // The callback needs a mutable cell; a plain local cannot be assigned from a lambda.
            final DoubleRef margin = DoubleRef.create(0.0d);
            data.foreachNonZero((index, value) -> {
                margin.elem += value * weightsArray[rowOffset + index];
            });
            if (row == labelRow) {
                marginY = margin.elem;
            }
            if (margin.elem > maxMargin) {
                maxMargin = margin.elem;
                maxMarginIndex = row;
            }
            margins[row] = margin.elem;
        }

        // Pass 2: sum of exponentials. When the largest margin is positive, every margin is
        // shifted down by it so no exponent is positive. After the shift the largest row is
        // exactly 0, i.e. exp == 1, which is what the later "+ 1.0d" supplies; its slot in the
        // sum therefore carries exp(-maxMargin), the shifted form of the implicit unit term.
        double sum = 0.0d;
        if (maxMargin > 0.0d) {
            for (int row = 0; row < numMargins; row++) {
                margins[row] -= maxMargin;
                if (row == maxMarginIndex) {
                    sum += package$.MODULE$.exp(-maxMargin);
                } else {
                    sum += package$.MODULE$.exp(margins[row]);
                }
            }
        } else {
            for (int row = 0; row < numMargins; row++) {
                sum += package$.MODULE$.exp(margins[row]);
            }
        }

        // Pass 3: add each row's contribution to the cumulative gradient.
        for (int row = 0; row < numMargins; row++) {
            final double indicator = (label != 0.0d && label == ((double) (row + 1))) ? 1.0d : 0.0d;
            final double multiplier = (package$.MODULE$.exp(margins[row]) / (sum + 1.0d)) - indicator;
            final int rowOffset = row * dataSize;
            data.foreachNonZero((index, value) -> {
                cumGradientArray[rowOffset + index] += multiplier * value;
            });
        }

        // marginY was captured before the shift, so the shift is added back to the loss below.
        double loss = package$.MODULE$.log1p(sum);
        if (label > 0.0d) {
            loss -= marginY;
        }
        if (maxMargin > 0.0d) {
            loss += maxMargin;
        }
        return loss;
    }
}
