/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier.mlp;

import java.io.File;
import org.apache.mahout.classifier.mlp.Datasets;
import org.apache.mahout.classifier.mlp.MultilayerPerceptron;
import org.apache.mahout.classifier.mlp.TrainMultilayerPerceptron;
import org.apache.mahout.common.MahoutTestCase;
import org.junit.Test;

/**
 * Tests for {@link TrainMultilayerPerceptron}: runs the command-line trainer on the Iris
 * dataset ({@code Datasets.IRIS}) and checks the model file it writes.
 */
public class TrainMultilayerPerceptronTest
extends MahoutTestCase {

    /** Tolerance used when comparing floating-point training parameters. */
    private static final double EPSILON = 1.0E-6;

    /**
     * Trains a 4-8-3 network on Iris and checks that the trainer wrote the model file.
     */
    @Test
    public void testIrisDataset() throws Exception {
        File modelFile = getTestTempFile("mlp.model");
        File irisDataset = writeIrisDataset();
        String[] args = {
            "-i", irisDataset.getAbsolutePath(),
            "-sh",
            "-labels", "setosa", "versicolor", "virginica",
            "-mo", modelFile.getAbsolutePath(),
            "-u",
            "-ls", "4", "8", "3"
        };
        TrainMultilayerPerceptron.main(args);
        assertTrue(modelFile.exists());
    }

    /**
     * Trains two models and checks what each one reports when loaded back.
     *
     * <p>The first model sets learning rate, momentum and regularization explicitly
     * ({@code -l}, {@code -m}, {@code -r}). The second model omits them and is expected to
     * fall back to 0.5 / 0.1 / 0.0. The second model deliberately gets its own file, so it is
     * not written on top of (or influenced by) the first model's file.
     *
     * <p>For every non-output layer, {@code getLayerSize} is expected to report one more than
     * the size given to {@code -ls} (presumably an extra bias unit), hence the {@code - 1}.
     */
    @Test
    public void initializeModelWithDifferentParameters() throws Exception {
        File irisDataset = writeIrisDataset();

        File modelFile1 = getTestTempFile("mlp.model");
        String[] args1 = {
            "-i", irisDataset.getAbsolutePath(),
            "-sh",
            "-labels", "setosa", "versicolor", "virginica",
            "-mo", modelFile1.getAbsolutePath(),
            "-u",
            "-ls", "4", "8", "3",
            "-l", "0.2",
            "-m", "0.35",
            "-r", "0.0001"
        };
        MultilayerPerceptron mlp1 = trainModel(args1, modelFile1);
        assertEquals(0.2, mlp1.getLearningRate(), EPSILON);
        assertEquals(0.35, mlp1.getMomentumWeight(), EPSILON);
        assertEquals(1.0E-4, mlp1.getRegularizationWeight(), EPSILON);
        assertEquals(4L, mlp1.getLayerSize(0) - 1);
        assertEquals(8L, mlp1.getLayerSize(1) - 1);
        assertEquals(3L, mlp1.getLayerSize(2));

        // Distinct file name: previously this reused "mlp.model" and overwrote modelFile1.
        File modelFile2 = getTestTempFile("mlp-defaults.model");
        String[] args2 = {
            "-i", irisDataset.getAbsolutePath(),
            "-sh",
            "-labels", "setosa", "versicolor", "virginica",
            "-mo", modelFile2.getAbsolutePath(),
            "-ls", "4", "10", "18", "3"
        };
        MultilayerPerceptron mlp2 = trainModel(args2, modelFile2);
        assertEquals(0.5, mlp2.getLearningRate(), EPSILON);
        assertEquals(0.1, mlp2.getMomentumWeight(), EPSILON);
        assertEquals(0.0, mlp2.getRegularizationWeight(), EPSILON);
        assertEquals(4L, mlp2.getLayerSize(0) - 1);
        assertEquals(10L, mlp2.getLayerSize(1) - 1);
        assertEquals(18L, mlp2.getLayerSize(2) - 1);
        assertEquals(3L, mlp2.getLayerSize(3));
    }

    /**
     * Writes the bundled Iris CSV to a test temp file named {@code iris.csv}.
     *
     * @return the file containing the Iris data
     */
    private File writeIrisDataset() throws Exception {
        File irisDataset = getTestTempFile("iris.csv");
        writeLines(irisDataset, Datasets.IRIS);
        return irisDataset;
    }

    /**
     * Runs the trainer's command-line entry point and loads the model it wrote.
     *
     * @param args      command-line arguments for {@link TrainMultilayerPerceptron#main}
     * @param modelFile the file named by the {@code -mo} argument in {@code args}
     * @return the model loaded from {@code modelFile}
     */
    private MultilayerPerceptron trainModel(String[] args, File modelFile) throws Exception {
        TrainMultilayerPerceptron.main(args);
        return new MultilayerPerceptron(modelFile.getAbsolutePath());
    }
}

