package org.apache.spark.ml.optim.aggregator;

import org.apache.spark.ml.linalg.BLAS$;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator;
import scala.Array$;
import scala.Predef$;
import scala.Serializable;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;

/* compiled from: DifferentiableLossAggregator.scala */
@ScalaSignature(bytes = "\u0006\u0001\u00154\u0011\"\u0004\b\u0011\u0002\u0007\u0005!C\u0007#\t\u000b\u0015\u0002A\u0011A\u0014\t\u000f-\u0002\u0001\u0019!C\tY!9\u0001\u0007\u0001a\u0001\n#\t\u0004b\u0002\u001b\u0001\u0001\u0004%\t\u0002\f\u0005\bk\u0001\u0001\r\u0011\"\u00057\u0011\u001dA\u0004A1A\u0007\u0012eB\u0001\"\u0010\u0001\t\u0006\u0004%\tB\u0010\u0005\u0006\u0005\u00021\ta\u0011\u0005\u0006/\u0002!\t\u0001\u0017\u0005\u00067\u0002!\t\u0001\u0018\u0005\u0006G\u0002!\t\u0001\f\u0005\u0006I\u0002!\t\u0001\f\u0002\u001d\t&4g-\u001a:f]RL\u0017M\u00197f\u0019>\u001c8/Q4he\u0016<\u0017\r^8s\u0015\ty\u0001#\u0001\u0006bO\u001e\u0014XmZ1u_JT!!\u0005\n\u0002\u000b=\u0004H/[7\u000b\u0005M!\u0012AA7m\u0015\t)b#A\u0003ta\u0006\u00148N\u0003\u0002\u00181\u00051\u0011\r]1dQ\u0016T\u0011!G\u0001\u0004_J<WcA\u000eP\rN\u0019\u0001\u0001\b\u0012\u0011\u0005u\u0001S\"\u0001\u0010\u000b\u0003}\tQa]2bY\u0006L!!\t\u0010\u0003\r\u0005s\u0017PU3g!\ti2%\u0003\u0002%=\ta1+\u001a:jC2L'0\u00192mK\u00061A%\u001b8ji\u0012\u001a\u0001\u0001F\u0001)!\ti\u0012&\u0003\u0002+=\t!QK\\5u\u0003%9X-[4iiN+X.F\u0001.!\tib&\u0003\u00020=\t1Ai\\;cY\u0016\fQb^3jO\"$8+^7`I\u0015\fHC\u0001\u00153\u0011\u001d\u00194!!AA\u00025\n1\u0001\u001f\u00132\u0003\u001dawn]:Tk6\f1\u0002\\8tgN+Xn\u0018\u0013fcR\u0011\u0001f\u000e\u0005\bg\u0015\t\t\u00111\u0001.\u0003\r!\u0017.\\\u000b\u0002uA\u0011QdO\u0005\u0003yy\u00111!\u00138u\u0003A9'/\u00193jK:$8+^7BeJ\f\u00170F\u0001@!\ri\u0002)L\u0005\u0003\u0003z\u0011Q!\u0011:sCf\f1!\u00193e)\t!U\u000b\u0005\u0002F\r2\u0001A!B$\u0001\u0005\u0004A%aA!hOF\u0011\u0011\n\u0014\t\u0003;)K!a\u0013\u0010\u0003\u000f9{G\u000f[5oOB!Q\n\u0001(E\u001b\u0005q\u0001CA#P\t\u0015\u0001\u0006A1\u0001R\u0005\u0015!\u0015\r^;n#\tI%\u000b\u0005\u0002\u001e'&\u0011AK\b\u0002\u0004\u0003:L\b\"\u0002,\t\u0001\u0004q\u0015\u0001C5ogR\fgnY3\u0002\u000b5,'oZ3\u0015\u0005\u0011K\u0006\"\u0002.\n\u0001\u0004!\u0015!B8uQ\u0016\u0014\u0018\u0001C4sC\u0012LWM\u001c;\u0016\u0003u\u0003\"AX1\u000e\u00
03}S!\u0001\u0019\n\u0002\r1Lg.\u00197h\u0013\t\u0011wL\u0001\u0004WK\u000e$xN]\u0001\u0007o\u0016Lw\r\u001b;\u0002\t1|7o\u001d")
/* loaded from: input_file:org/apache/spark/ml/optim/aggregator/DifferentiableLossAggregator.class */
/*
 * Accumulates a weighted loss sum, a weight sum and a per-coordinate gradient sum over data
 * points of type Datum, and can merge with another aggregator of the same dimension.
 * Agg is the concrete (self) aggregator type returned by add/merge.
 */
public interface DifferentiableLossAggregator<Datum, Agg extends DifferentiableLossAggregator<Datum, Agg>> extends Serializable {
    /** Returns the accumulated sum of instance weights. */
    double weightSum();

    /** Sets the accumulated sum of instance weights (Scala {@code var} setter). */
    void weightSum_$eq(double d);

    /** Returns the accumulated (weighted) loss sum. */
    double lossSum();

    /** Sets the accumulated loss sum (Scala {@code var} setter). */
    void lossSum_$eq(double d);

    /** Returns the number of gradient coordinates this aggregator tracks. */
    int dim();

    /**
     * Returns the array holding the accumulated, not-yet-normalized gradient sums.
     *
     * <p>This default allocates a NEW zero-filled array of length {@link #dim()} on every call.
     * {@link #merge} writes into the array returned here and {@link #gradient()} reads from it,
     * so implementations MUST override this to return the same, cached array on every call;
     * otherwise accumulated gradients are silently discarded.
     * NOTE(review): in the original Scala this is presumably a {@code lazy val} memoized by the
     * implementing class — confirm every implementer overrides it.
     *
     * @return the gradient-sum storage, of length {@code dim()}
     */
    default double[] gradientSumArray() {
        return new double[dim()];
    }

    /**
     * Adds one data point's contribution to this aggregator.
     *
     * @param datum the data point to add
     * @return an aggregator of the concrete type (implementation-defined; presumably {@code this})
     */
    Agg add(Datum datum);

    /**
     * Merges another aggregator's sums into this one, in place.
     *
     * <p>If {@code agg} has a zero weight sum, nothing is merged. Otherwise its weight sum and
     * loss sum are added to this aggregator's, and its gradient sums are added element-wise
     * ({@code this.gradient += 1.0 * agg.gradient} via BLAS {@code daxpy}).
     *
     * @param agg the aggregator to merge in; must have the same {@link #dim()}
     * @return this aggregator
     * @throws IllegalArgumentException if the dimensions differ (raised by {@code Predef.require})
     */
    @SuppressWarnings("unchecked") // `this` is always an Agg by the recursive type bound on Agg
    default Agg merge(Agg agg) {
        final int n = dim();
        Predef$.MODULE$.require(n == agg.dim(), () ->
            "Dimensions mismatch when merging with another " + this.getClass().getSimpleName()
                + ". Expecting " + this.dim() + " but got " + agg.dim() + ".");
        if (agg.weightSum() != 0) {
            weightSum_$eq(weightSum() + agg.weightSum());
            lossSum_$eq(lossSum() + agg.lossSum());
            // y := 1.0 * x + y, accumulating agg's gradient sums into ours.
            BLAS$.MODULE$.getBLAS(n).daxpy(n, 1.0d, agg.gradientSumArray(), 1, gradientSumArray(), 1);
        }
        // Java cannot convert `this` to the type variable Agg implicitly; the erased cast is a no-op.
        return (Agg) this;
    }

    /**
     * Returns the weight-averaged gradient as a new dense vector.
     *
     * <p>The result is a copy of {@link #gradientSumArray()} scaled by {@code 1 / weightSum()};
     * the accumulated sums are not modified.
     *
     * @return the averaged gradient
     * @throws IllegalArgumentException if {@code weightSum()} is not greater than 0.0
     */
    default Vector gradient() {
        Predef$.MODULE$.require(weightSum() > 0.0d, () ->
            "The effective number of instances should be " + "greater than 0.0, but was "
                + this.weightSum() + ".");
        // Copy so that scaling does not touch the accumulator's own storage.
        final Vector averaged = Vectors$.MODULE$.dense(gradientSumArray().clone());
        BLAS$.MODULE$.scal(1.0d / weightSum(), averaged);
        return averaged;
    }

    /** Returns the accumulated weight; identical to {@link #weightSum()}. */
    default double weight() {
        return weightSum();
    }

    /**
     * Returns the weight-averaged loss, {@code lossSum() / weightSum()}.
     *
     * @return the averaged loss
     * @throws IllegalArgumentException if {@code weightSum()} is not greater than 0.0
     */
    default double loss() {
        Predef$.MODULE$.require(weightSum() > 0.0d, () ->
            "The effective number of instances should be " + "greater than 0.0, but was "
                + this.weightSum() + ".");
        return lossSum() / weightSum();
    }

    /**
     * Trait initializer emitted by the Scala compiler: resets the weight and loss sums to zero.
     *
     * @param differentiableLossAggregator the freshly constructed aggregator to initialize
     */
    static void $init$(DifferentiableLossAggregator<?, ?> differentiableLossAggregator) {
        differentiableLossAggregator.weightSum_$eq(0.0d);
        differentiableLossAggregator.lossSum_$eq(0.0d);
    }
}
