package org.apache.mahout.classifier.mlp;

import java.io.File;
import java.io.IOException;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.math.Arrays;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.junit.Test;

/* loaded from: input_file:org/apache/mahout/classifier/mlp/TestMultilayerPerceptron.class */
public class TestMultilayerPerceptron extends MahoutTestCase {

    /** Output values at or above this threshold are interpreted as class 1, below it as class 0. */
    private static final double DECISION_THRESHOLD = 0.5d;

    /**
     * Trains a 2-3-1 sigmoid network on XOR under three configurations (plain, with momentum,
     * with momentum plus regularization), and checks both the trained and the reloaded model.
     */
    @Test
    public void testMLP() throws IOException {
        testMLP("testMLPXORLocal", false, false, 8000);
        testMLP("testMLPXORLocalWithMomentum", true, false, 4000);
        testMLP("testMLPXORLocalWithRegularization", true, true, 2000);
    }

    /**
     * Trains a network on the XOR truth table, asserts that it classifies every instance
     * correctly, writes it to a temp file, reloads it, and asserts the reloaded model agrees.
     *
     * @param modelFileName     name of the temp file the trained model is written to
     * @param useMomentum       whether to set a momentum weight of 0.6
     * @param useRegularization whether to set a regularization weight of 0.01
     * @param iterations        number of full passes over the training instances
     * @throws IOException if writing or reading the model file fails
     */
    private void testMLP(String modelFileName, boolean useMomentum, boolean useRegularization,
            int iterations) throws IOException {
        // Each row is {x1, x2, label}; the label is always the last column.
        double[][] instances = {
            {0.0d, 1.0d, 1.0d},
            {0.0d, 0.0d, 0.0d},
            {1.0d, 0.0d, 1.0d},
            {1.0d, 1.0d, 0.0d}
        };

        MultilayerPerceptron mlp = new MultilayerPerceptron();
        File modelFile;
        try {
            mlp.addLayer(2, false, "Sigmoid");
            mlp.addLayer(3, false, "Sigmoid");
            mlp.addLayer(1, true, "Sigmoid");
            mlp.setCostFunction("Minus_Squared").setLearningRate(0.2d);
            if (useMomentum) {
                mlp.setMomentumWeight(0.6d);
            }
            if (useRegularization) {
                mlp.setRegularizationWeight(0.01d);
            }

            for (int iteration = 0; iteration < iterations; iteration++) {
                for (double[] instance : instances) {
                    // Features are every column except the trailing label.
                    Vector features = new DenseVector(Arrays.copyOf(instance, instance.length - 1));
                    mlp.train((int) instance[instance.length - 1], features);
                }
            }

            assertClassifiesAll(mlp, instances, "trained model");

            modelFile = getTestTempFile(modelFileName);
            mlp.setModelPath(modelFile.getAbsolutePath());
            mlp.writeModelToFile();
        } finally {
            // Release the model even when training, an assertion, or the write fails.
            mlp.close();
        }

        MultilayerPerceptron restored = new MultilayerPerceptron(modelFile.getAbsolutePath());
        try {
            assertClassifiesAll(restored, instances, "restored model");
        } finally {
            restored.close();
        }
    }

    /**
     * Asserts that, for every instance, the model's first output lies on the same side of
     * {@link #DECISION_THRESHOLD} as the instance's label.
     *
     * @param model       the network under test
     * @param instances   rows of features followed by a trailing label column
     * @param description short label for the model, used in failure messages
     */
    private void assertClassifiesAll(MultilayerPerceptron model, double[][] instances,
            String description) {
        for (double[] instance : instances) {
            int featureCount = instance.length - 1;
            Vector input = new DenseVector(instance).viewPart(0, featureCount);
            double expected = instance[featureCount];
            double actual = model.getOutput(input).get(0);
            boolean expectedPositive = expected >= DECISION_THRESHOLD;
            boolean actualPositive = actual >= DECISION_THRESHOLD;
            assertTrue(description + " misclassified input " + input + ": expected label "
                    + expected + " but output was " + actual, expectedPositive == actualPositive);
        }
    }
}
