/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier.mlp;

import java.io.File;
import java.io.IOException;
import org.apache.mahout.classifier.mlp.MultilayerPerceptron;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.math.Arrays;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.junit.Test;

/**
 * Trains a small {@link MultilayerPerceptron} on the XOR truth table and checks that it separates
 * every instance correctly, both right after training and after the model has been written to a
 * file and loaded back into a fresh instance.
 */
public class TestMultilayerPerceptron extends MahoutTestCase {

    /** XOR truth table; each row is {@code {x0, x1, label}} with the label in the last column. */
    private static final double[][] XOR_INSTANCES = {
        {0.0, 1.0, 1.0}, {0.0, 0.0, 0.0}, {1.0, 0.0, 1.0}, {1.0, 1.0, 0.0}
    };

    /** Values at or above this threshold count as class 1, values below it as class 0. */
    private static final double DECISION_THRESHOLD = 0.5;

    @Test
    public void testMLP() throws IOException {
        testMLP("testMLPXORLocal", false, false, 8000);
        testMLP("testMLPXORLocalWithMomentum", true, false, 4000);
        testMLP("testMLPXORLocalWithRegularization", true, true, 2000);
    }

    /**
     * Builds a network with layers of size 2, 3 and 1 (all configured with "Sigmoid"), trains it on
     * XOR, verifies it, writes it to a temp file, reloads it and verifies the reloaded copy.
     *
     * @param modelFilename name of the temporary file the trained model is written to
     * @param useMomentum whether to set a momentum weight of 0.6
     * @param useRegularization whether to set a regularization weight of 0.01
     * @param iterations number of full passes over the XOR instances during training
     * @throws IOException if writing, reading or closing the model fails
     */
    private void testMLP(String modelFilename, boolean useMomentum, boolean useRegularization, int iterations)
            throws IOException {
        File modelFile = getTestTempFile(modelFilename);
        String modelPath = modelFile.getAbsolutePath();

        MultilayerPerceptron mlp = new MultilayerPerceptron();
        try {
            mlp.addLayer(2, false, "Sigmoid");
            mlp.addLayer(3, false, "Sigmoid");
            mlp.addLayer(1, true, "Sigmoid");
            mlp.setCostFunction("Minus_Squared").setLearningRate(0.2);
            if (useMomentum) {
                mlp.setMomentumWeight(0.6);
            }
            if (useRegularization) {
                mlp.setRegularizationWeight(0.01);
            }

            for (int i = 0; i < iterations; ++i) {
                for (double[] instance : XOR_INSTANCES) {
                    // Labels are 0.0 / 1.0, so the cast yields the integer class index.
                    mlp.train((int) labelOf(instance), featuresOf(instance));
                }
            }
            assertClassifiesXor(mlp);

            mlp.setModelPath(modelPath);
            mlp.writeModelToFile();
        } finally {
            // Close even when an assertion above fails, so the model is not leaked.
            mlp.close();
        }

        MultilayerPerceptron mlpCopy = new MultilayerPerceptron(modelPath);
        try {
            assertClassifiesXor(mlpCopy);
        } finally {
            mlpCopy.close();
        }
    }

    /**
     * Asserts that {@code model}'s first output lies on the same side of {@link #DECISION_THRESHOLD}
     * as the label for every XOR instance.
     */
    private static void assertClassifiesXor(MultilayerPerceptron model) {
        for (double[] instance : XOR_INSTANCES) {
            Vector input = featuresOf(instance);
            double label = labelOf(instance);
            double prediction = model.getOutput(input).get(0);
            // Spelled out with both comparisons (rather than comparing two booleans) so that a
            // NaN prediction fails the check instead of being silently treated as class 0.
            boolean agrees = (label < DECISION_THRESHOLD && prediction < DECISION_THRESHOLD)
                    || (label >= DECISION_THRESHOLD && prediction >= DECISION_THRESHOLD);
            assertTrue("XOR misclassified for input " + input
                    + ": label=" + label + ", prediction=" + prediction, agrees);
        }
    }

    /** Returns every column of {@code instance} except the last one, as a feature vector. */
    private static Vector featuresOf(double[] instance) {
        return new DenseVector(Arrays.copyOf(instance, instance.length - 1));
    }

    /** Returns the last column of {@code instance}, which holds the expected XOR output. */
    private static double labelOf(double[] instance) {
        return instance[instance.length - 1];
    }
}

