/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.ann;

import breeze.linalg.DenseMatrix;
import java.io.File;
import java.io.Serializable;
import org.apache.spark.SparkContext;
import org.apache.spark.SparkFunSuite;
import org.apache.spark.ml.ann.FeedForwardTopology;
import org.apache.spark.ml.ann.FeedForwardTopology$;
import org.apache.spark.ml.ann.Layer;
import org.apache.spark.ml.ann.LayerModel;
import org.apache.spark.ml.ann.LossFunction;
import org.apache.spark.ml.ann.SigmoidLayerWithSquaredError;
import org.apache.spark.ml.ann.SoftmaxLayerWithCrossEntropyLoss;
import org.apache.spark.ml.ann.TopologyModel;
import org.apache.spark.ml.feature.Instance;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.ml.util.TempDirectory;
import org.apache.spark.mllib.util.MLlibTestSparkContext;
import org.apache.spark.mllib.util.MLlibTestSparkContext$testImplicits$;
import org.apache.spark.sql.SparkSession;
import org.scalactic.Bool;
import org.scalactic.Bool$;
import org.scalactic.Prettifier$;
import org.scalactic.source.Position;
import org.scalatest.Assertions$;
import org.scalatest.Tag;
import scala.Array$;
import scala.Function0;
import scala.Predef$;
import scala.collection.Seq;
import scala.collection.immutable.;
import scala.collection.immutable.List;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.ArrayOps;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.java8.JFunction0;

@ScalaSignature(bytes="\u0006\u0001e2Aa\u0001\u0003\u0001\u001f!)A\u0004\u0001C\u0001;!)\u0001\u0005\u0001C\u0005C\tiqI]1eS\u0016tGoU;ji\u0016T!!\u0002\u0004\u0002\u0007\u0005tgN\u0003\u0002\b\u0011\u0005\u0011Q\u000e\u001c\u0006\u0003\u0013)\tQa\u001d9be.T!a\u0003\u0007\u0002\r\u0005\u0004\u0018m\u00195f\u0015\u0005i\u0011aA8sO\u000e\u00011c\u0001\u0001\u0011)A\u0011\u0011CE\u0007\u0002\u0011%\u00111\u0003\u0003\u0002\u000e'B\f'o\u001b$v]N+\u0018\u000e^3\u0011\u0005UQR\"\u0001\f\u000b\u0005]A\u0012\u0001B;uS2T!!\u0007\u0005\u0002\u000b5dG.\u001b2\n\u0005m1\"!F'MY&\u0014G+Z:u'B\f'o[\"p]R,\u0007\u0010^\u0001\u0007y%t\u0017\u000e\u001e \u0015\u0003y\u0001\"a\b\u0001\u000e\u0003\u0011\t1bY8naV$X\rT8tgR!!\u0005\u000b\u001a5!\t\u0019c%D\u0001%\u0015\u0005)\u0013!B:dC2\f\u0017BA\u0014%\u0005\u0019!u.\u001e2mK\")\u0011F\u0001a\u0001U\u0005)\u0011N\u001c9viB\u00191\u0006\r\u0012\u000e\u00031R!!\f\u0018\u0002\r1Lg.\u00197h\u0015\u0005y\u0013A\u00022sK\u0016TX-\u0003\u00022Y\tYA)\u001a8tK6\u000bGO]5y\u0011\u0015\u0019$\u00011\u0001+\u0003\u0019!\u0018M]4fi\")QG\u0001a\u0001m\u0005)Qn\u001c3fYB\u0011qdN\u0005\u0003q\u0011\u0011Q\u0002V8q_2|w-_'pI\u0016d\u0007")
public class GradientSuite
extends SparkFunSuite
implements MLlibTestSparkContext {
    private transient SparkSession spark;
    private transient SparkContext sc;
    private transient String checkpointDir;
    private volatile MLlibTestSparkContext$testImplicits$ testImplicits$module;
    private File org$apache$spark$ml$util$TempDirectory$$_tempDir;

    /**
     * Synthetic accessor that forwards to {@code TempDirectory.beforeAll$(this)}.
     * NOTE(review): the mangled name suggests this stands in for the
     * {@code super.beforeAll()} call made from {@code MLlibTestSparkContext};
     * that is inferred from the name only.
     */
    @Override
    public /* synthetic */ void org$apache$spark$mllib$util$MLlibTestSparkContext$$super$beforeAll() {
        TempDirectory.beforeAll$(this);
    }

    /**
     * Synthetic accessor that forwards to {@code TempDirectory.afterAll$(this)}.
     * NOTE(review): the mangled name suggests this stands in for the
     * {@code super.afterAll()} call made from {@code MLlibTestSparkContext};
     * that is inferred from the name only.
     */
    @Override
    public /* synthetic */ void org$apache$spark$mllib$util$MLlibTestSparkContext$$super$afterAll() {
        TempDirectory.afterAll$(this);
    }

    /**
     * Suite-level setup hook; forwards to {@code MLlibTestSparkContext.beforeAll$(this)}.
     */
    @Override
    public void beforeAll() {
        MLlibTestSparkContext.beforeAll$(this);
    }

    /**
     * Suite-level teardown hook; forwards to {@code MLlibTestSparkContext.afterAll$(this)}.
     */
    @Override
    public void afterAll() {
        MLlibTestSparkContext.afterAll$(this);
    }

    /**
     * Forwards to {@code MLlibTestSparkContext.standardize$(this, instances)} and
     * returns its result unchanged.
     *
     * @param instances the instances handed to the trait implementation
     * @return whatever the trait implementation returns
     */
    @Override
    public Instance[] standardize(Instance[] instances) {
        return MLlibTestSparkContext.standardize$(this, instances);
    }

    /**
     * Synthetic accessor that invokes the superclass {@code beforeAll()}.
     * NOTE(review): the mangled name suggests it exists so {@code TempDirectory}
     * can reach the superclass implementation; inferred from the name only.
     */
    @Override
    public /* synthetic */ void org$apache$spark$ml$util$TempDirectory$$super$beforeAll() {
        super.beforeAll();
    }

    /**
     * Synthetic accessor that invokes the superclass {@code afterAll()}.
     * NOTE(review): the mangled name suggests it exists so {@code TempDirectory}
     * can reach the superclass implementation; inferred from the name only.
     */
    @Override
    public /* synthetic */ void org$apache$spark$ml$util$TempDirectory$$super$afterAll() {
        super.afterAll();
    }

    /**
     * Forwards to {@code TempDirectory.tempDir$(this)} and returns its result.
     *
     * @return the {@link File} produced by the trait implementation
     */
    @Override
    public File tempDir() {
        return TempDirectory.tempDir$(this);
    }

    /**
     * Getter for the {@code spark} field.
     *
     * @return the current value of the field (may be {@code null} if never assigned)
     */
    @Override
    public SparkSession spark() {
        return this.spark;
    }

    /**
     * Setter for the {@code spark} field (Scala {@code spark_=} in JVM name mangling).
     *
     * @param x$1 the value to store
     */
    @Override
    public void spark_$eq(SparkSession x$1) {
        this.spark = x$1;
    }

    /**
     * Getter for the {@code sc} field.
     *
     * @return the current value of the field (may be {@code null} if never assigned)
     */
    @Override
    public SparkContext sc() {
        return this.sc;
    }

    /**
     * Setter for the {@code sc} field (Scala {@code sc_=} in JVM name mangling).
     *
     * @param x$1 the value to store
     */
    @Override
    public void sc_$eq(SparkContext x$1) {
        this.sc = x$1;
    }

    /**
     * Getter for the {@code checkpointDir} field.
     *
     * @return the current value of the field (may be {@code null} if never assigned)
     */
    @Override
    public String checkpointDir() {
        return this.checkpointDir;
    }

    /**
     * Setter for the {@code checkpointDir} field (Scala {@code checkpointDir_=} in
     * JVM name mangling).
     *
     * @param x$1 the value to store
     */
    @Override
    public void checkpointDir_$eq(String x$1) {
        this.checkpointDir = x$1;
    }

    /**
     * Returns the {@code testImplicits} module instance, creating it on first use.
     * The field is read first; only when it is still {@code null} is the
     * synchronized initializer {@code testImplicits$lzycompute$1()} invoked.
     *
     * @return the module instance held in {@code testImplicits$module}
     */
    @Override
    public MLlibTestSparkContext$testImplicits$ testImplicits() {
        MLlibTestSparkContext$testImplicits$ cached = this.testImplicits$module;
        if (cached != null) {
            // Already initialized: nothing else to do.
            return cached;
        }
        this.testImplicits$lzycompute$1();
        return this.testImplicits$module;
    }

    /**
     * Getter for the backing field {@code org$apache$spark$ml$util$TempDirectory$$_tempDir}.
     *
     * @return the current value of the field (may be {@code null} if never assigned)
     */
    @Override
    public File org$apache$spark$ml$util$TempDirectory$$_tempDir() {
        return this.org$apache$spark$ml$util$TempDirectory$$_tempDir;
    }

    /**
     * Setter for the backing field {@code org$apache$spark$ml$util$TempDirectory$$_tempDir}.
     *
     * @param x$1 the value to store
     */
    @Override
    public void org$apache$spark$ml$util$TempDirectory$$_tempDir_$eq(File x$1) {
        this.org$apache$spark$ml$util$TempDirectory$$_tempDir = x$1;
    }

    /**
     * Runs a forward pass of {@code model} on {@code input} and evaluates the loss of
     * the model's last layer against {@code target}.
     *
     * @param input  matrix fed to {@code model.forward(input, true)}
     * @param target matrix passed to the loss function; its row/column counts also
     *               size the extra matrix handed to {@code loss}
     * @param model  topology model whose last layer model must implement {@link LossFunction}
     * @return the value returned by the last layer's {@code loss(...)}
     * @throws UnsupportedOperationException if the last layer model is not a {@link LossFunction}
     */
    private double computeLoss(DenseMatrix<Object> input, DenseMatrix<Object> target, TopologyModel model) {
        // The forward pass runs before the layer-type check, as in the original.
        DenseMatrix[] outputs = model.forward(input, true);
        Object[] layerModels = (Object[])model.layerModels();
        LayerModel topLayer = (LayerModel)new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(layerModels)).last();

        if (!(topLayer instanceof LossFunction)) {
            Object failedLayer = new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])model.layerModels())).last();
            throw new UnsupportedOperationException("Top layer is required to have loss. Failed layer:" + failedLayer.getClass());
        }

        LossFunction lossFunction = (LossFunction)topLayer;
        DenseMatrix lastOutput = (DenseMatrix)new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])outputs)).last();
        // Freshly allocated matrix with the same dimensions as target; third argument of loss(...).
        DenseMatrix targetShaped = (DenseMatrix)new DenseMatrix.mcD.sp(target.rows(), target.cols(), ClassTag$.MODULE$.Double());
        return lossFunction.loss(lastOutput, target, targetShaped);
    }

    /**
     * Initializes {@code testImplicits$module} while holding this instance's monitor.
     * The field is re-checked inside the lock so that it is assigned at most once.
     */
    private final void testImplicits$lzycompute$1() {
        synchronized (this) {
            if (this.testImplicits$module != null) {
                // Another caller finished initialization first.
                return;
            }
            this.testImplicits$module = new MLlibTestSparkContext$testImplicits$(this);
        }
    }

    /**
     * Gradient check for one loss layer: installs {@code layerWithError} as the last
     * layer of the topology, obtains the gradient from {@code computeGradient}, and
     * compares every component against a forward-difference estimate
     * {@code (loss(w + eps * e_i) - loss(w)) / eps}.
     *
     * NOTE(review): {@code Prettifier$.MODULE$.default()} is reproduced from the
     * decompiler output; {@code default} is a reserved word in Java source.
     *
     * @param $this          the suite instance, used for {@code computeLoss}
     * @param topology$1     topology whose last layer slot is overwritten
     * @param input$1        input matrix for the model
     * @param target$1       target matrix for the loss
     * @param layerWithError layer placed in the last slot of {@code topology$1.layers()}
     */
    public static final /* synthetic */ void $anonfun$new$2(GradientSuite $this, FeedForwardTopology topology$1, DenseMatrix input$1, DenseMatrix target$1, Layer layerWithError) {
        int lastLayerIndex = topology$1.layers().length - 1;
        topology$1.layers()[lastLayerIndex] = layerWithError;

        TopologyModel model = topology$1.model(12L);
        double[] weights = model.weights().toArray();
        int numWeights = weights.length;
        // A new double[] is zero-filled, which matches filling numWeights slots with 0.0.
        Vector gradient = Vectors$.MODULE$.dense(new double[numWeights]);
        double loss = model.computeGradient(input$1, target$1, gradient, 1);

        final double eps = 1.0E-4;
        final double tol = 1.0E-4;
        for (int i = 0; i < numWeights; i++) {
            double savedWeight = weights[i];
            // Perturb a single weight, rebuild the model from the modified array, and re-evaluate.
            weights[i] += eps;
            TopologyModel perturbedModel = topology$1.model(Vectors$.MODULE$.dense(weights));
            double perturbedLoss = $this.computeLoss((DenseMatrix<Object>)input$1, (DenseMatrix<Object>)target$1, perturbedModel);
            double derivativeEstimate = (perturbedLoss - loss) / eps;

            double absError = package$.MODULE$.abs(gradient.apply(i) - derivativeEstimate);
            Bool withinTolerance = Bool$.MODULE$.binaryMacroBool((Object)BoxesRunTime.boxToDouble(absError), "<", (Object)BoxesRunTime.boxToDouble(tol), absError < tol, Prettifier$.MODULE$.default());
            String clue = "Layer failed gradient check: " + layerWithError.getClass();
            Assertions$.MODULE$.assertionsHelper().macroAssert(withinTolerance, (Object)clue, Prettifier$.MODULE$.default(), new Position("GradientSuite.scala", "Please set the environment variable SCALACTIC_FILL_FILE_PATHNAMES to yes at compile time to enable this feature.", 58));

            // Restore the weight before moving to the next index (skipped if the assertion throws).
            weights[i] = savedWeight;
        }
    }

    /**
     * Runs the trait initializers for {@code TempDirectory} and
     * {@code MLlibTestSparkContext}, then registers the single test
     * "Gradient computation against numerical differentiation".
     *
     * The test body builds a 3x1 input, a 2x1 target and a {3, 4, 2} multilayer
     * perceptron topology, then runs the gradient check once for each of
     * {@link SigmoidLayerWithSquaredError} and {@link SoftmaxLayerWithCrossEntropyLoss}.
     */
    public GradientSuite() {
        TempDirectory.$init$(this);
        MLlibTestSparkContext.$init$(this);
        this.test("Gradient computation against numerical differentiation", (Seq)Predef$.MODULE$.wrapRefArray((Object[])new Tag[0]), (Function0)(JFunction0.mcV.sp & Serializable & scala.Serializable)() -> {
            DenseMatrix.mcD.sp input = new DenseMatrix.mcD.sp(3, 1, new double[]{1.0, 1.0, 1.0});
            DenseMatrix.mcD.sp target = new DenseMatrix.mcD.sp(2, 1, new double[]{0.0, 1.0});
            FeedForwardTopology topology = FeedForwardTopology$.MODULE$.multiLayerPerceptron(new int[]{3, 4, 2}, false);
            // Scala's list-cons class `::` has the JVM name `$colon$colon`; the decompiler's
            // `.colon.colon` rendering is not parseable Java, so the class is named explicitly.
            // The nested form keeps the original construction order (sigmoid layer first).
            Seq layersWithErrors = (Seq)new scala.collection.immutable.$colon$colon((Object)new SigmoidLayerWithSquaredError(), (List)new scala.collection.immutable.$colon$colon((Object)new SoftmaxLayerWithCrossEntropyLoss(), (List)Nil$.MODULE$));
            // The raw Seq yields untyped elements, so the cast to Layer must be explicit.
            layersWithErrors.foreach(arg_0 -> GradientSuite.$anonfun$new$2$adapted(this, topology, (DenseMatrix)input, (DenseMatrix)target, (Layer)arg_0));
        }, new Position("GradientSuite.scala", "Please set the environment variable SCALACTIC_FILL_FILE_PATHNAMES to yes at compile time to enable this feature.", 28));
    }

    /**
     * Adapter around {@code $anonfun$new$2}: runs the gradient check with the given
     * arguments and returns {@code BoxedUnit.UNIT} so the call has an {@code Object}
     * result.
     *
     * @return always {@code BoxedUnit.UNIT}
     */
    public static final /* synthetic */ Object $anonfun$new$2$adapted(GradientSuite $this, FeedForwardTopology topology$1, DenseMatrix input$1, DenseMatrix target$1, Layer layerWithError) {
        GradientSuite.$anonfun$new$2($this, topology$1, input$1, target$1, layerWithError);
        return BoxedUnit.UNIT;
    }
}

