/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier.mlp;

import com.google.common.base.Charsets;
import com.google.common.collect.Lists;
import com.google.common.io.Files;
import java.io.File;
import java.io.IOException;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import org.apache.commons.csv.CSVUtils;
import org.apache.mahout.classifier.mlp.Datasets;
import org.apache.mahout.classifier.mlp.MultilayerPerceptron;
import org.apache.mahout.classifier.mlp.NeuralNetwork;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.junit.Test;

public class TestNeuralNetwork
extends MahoutTestCase {

    /** Tolerance for comparing doubles that should be reproduced exactly (round trips, fixed weights). */
    private static final double EPSILON = 1.0E-6;

    /** Seed for shuffling the iris records so the train/test split is identical on every run. */
    private static final long IRIS_SHUFFLE_SEED = 42L;

    /**
     * Writes a model with known hyper-parameters and constant weights to disk, reloads it and
     * checks that every persisted property and every weight is restored.
     */
    @Test
    public void testReadWrite() throws IOException {
        MultilayerPerceptron ann = new MultilayerPerceptron();
        ann.addLayer(2, false, "Identity");
        ann.addLayer(5, false, "Identity");
        ann.addLayer(1, true, "Identity");
        ann.setCostFunction("Minus_Squared");
        double learningRate = 0.2;
        double momentumWeight = 0.5;
        double regularizationWeight = 0.05;
        ann.setLearningRate(learningRate).setMomentumWeight(momentumWeight).setRegularizationWeight(regularizationWeight);

        // One matrix per layer transition: (next layer size) x (previous layer size + 1);
        // the extra column presumably carries the bias weight.
        Matrix[] matrices = new Matrix[2];
        matrices[0] = new DenseMatrix(5, 3);
        matrices[0].assign(0.2);
        matrices[1] = new DenseMatrix(1, 6);
        matrices[1].assign(0.8);
        ann.setWeightMatrices(matrices);

        File tmpModelFile = getTestTempFile("testNeuralNetworkReadWrite");
        ann.setModelPath(tmpModelFile.getAbsolutePath());
        ann.writeModelToFile();

        MultilayerPerceptron annCopy = new MultilayerPerceptron(tmpModelFile.getAbsolutePath());
        assertEquals(annCopy.getClass().getSimpleName(), annCopy.getModelType());
        assertEquals(tmpModelFile.getAbsolutePath(), annCopy.getModelPath());
        assertEquals(learningRate, annCopy.getLearningRate(), EPSILON);
        assertEquals(momentumWeight, annCopy.getMomentumWeight(), EPSILON);
        assertEquals(regularizationWeight, annCopy.getRegularizationWeight(), EPSILON);
        assertEquals(NeuralNetwork.TrainingMethod.GRADIENT_DESCENT, annCopy.getTrainingMethod());

        Matrix[] weightsMatrices = annCopy.getWeightMatrices();
        // Compare the counts first: otherwise a missing layer would go unnoticed and an
        // extra layer would surface as an ArrayIndexOutOfBoundsException instead of a failure.
        assertEquals(matrices.length, weightsMatrices.length);
        for (int i = 0; i < matrices.length; ++i) {
            Matrix expectMat = matrices[i];
            Matrix actualMat = weightsMatrices[i];
            assertEquals(expectMat.rowSize(), actualMat.rowSize());
            assertEquals(expectMat.columnSize(), actualMat.columnSize());
            for (int row = 0; row < expectMat.rowSize(); ++row) {
                for (int col = 0; col < expectMat.columnSize(); ++col) {
                    assertEquals(expectMat.get(row, col), actualMat.get(row, col), EPSILON);
                }
            }
        }
    }

    /**
     * Runs forward propagation on networks with hand-set constant weights and compares the
     * outputs to precomputed values.
     */
    @Test
    public void testOutput() {
        MultilayerPerceptron ann = new MultilayerPerceptron();
        ann.addLayer(2, false, "Identity");
        ann.addLayer(5, false, "Identity");
        ann.addLayer(1, true, "Identity");
        ann.setCostFunction("Minus_Squared").setLearningRate(0.1);
        Matrix[] matrices = new Matrix[2];
        matrices[0] = new DenseMatrix(5, 3);
        matrices[0].assign(0.5);
        matrices[1] = new DenseMatrix(1, 6);
        matrices[1].assign(0.5);
        ann.setWeightMatrices(matrices);
        Vector training = new DenseVector(new double[]{0.0, 1.0});
        Vector result = ann.getOutput(training);
        assertEquals(1, result.size());

        Vector zeroInput = new DenseVector(new double[]{0.0, 0.0});
        Vector zeroOutput = newConstantWeightSigmoidNetwork().getOutput(zeroInput);
        assertEquals(1, zeroOutput.size());
        assertEquals(0.807476, zeroOutput.get(0), EPSILON);

        Vector instance = new DenseVector(new double[]{0.0, 1.0});
        Vector output = newConstantWeightSigmoidNetwork().getOutput(instance);
        assertEquals(0.831541, output.get(0), EPSILON);
    }

    /**
     * Builds a 2-3-1 sigmoid network whose weights are all 0.5. The expected outputs in
     * {@link #testOutput()} were computed for this exact configuration.
     */
    private static MultilayerPerceptron newConstantWeightSigmoidNetwork() {
        MultilayerPerceptron ann = new MultilayerPerceptron();
        ann.addLayer(2, false, "Sigmoid");
        ann.addLayer(3, false, "Sigmoid");
        ann.addLayer(1, true, "Sigmoid");
        ann.setCostFunction("Minus_Squared").setLearningRate(0.3);
        Matrix[] weights = new Matrix[2];
        weights[0] = new DenseMatrix(3, 3);
        weights[0].assign(0.5);
        weights[1] = new DenseMatrix(1, 4);
        weights[1].assign(0.5);
        ann.setWeightMatrices(weights);
        return ann;
    }

    /** Trains on XOR with plain gradient descent, with momentum, and with momentum plus regularization. */
    @Test
    public void testNeuralNetwork() throws IOException {
        testNeuralNetwork("testNeuralNetworkXORLocal", false, false, 10000);
        testNeuralNetwork("testNeuralNetworkXORWithMomentum", true, false, 5000);
        testNeuralNetwork("testNeuralNetworkXORWithRegularization", true, true, 5000);
    }

    /**
     * Trains a 2-3-1 sigmoid network online on the four XOR patterns. It checks that the trained
     * model, and a copy reloaded from disk, classify every pattern correctly.
     *
     * @param modelFilename     name of the temporary file the model is written to
     * @param useMomentum       whether to set a momentum weight of 0.6
     * @param useRegularization whether to set a regularization weight of 0.01
     * @param iterations        number of passes over the four training patterns
     */
    private void testNeuralNetwork(String modelFilename, boolean useMomentum, boolean useRegularization, int iterations) throws IOException {
        MultilayerPerceptron ann = new MultilayerPerceptron();
        ann.addLayer(2, false, "Sigmoid");
        ann.addLayer(3, false, "Sigmoid");
        ann.addLayer(1, true, "Sigmoid");
        ann.setCostFunction("Minus_Squared").setLearningRate(0.1);
        if (useMomentum) {
            ann.setMomentumWeight(0.6);
        }
        if (useRegularization) {
            ann.setRegularizationWeight(0.01);
        }
        // Each row is {x1, x2, x1 XOR x2}; the label sits in the last column.
        double[][] instances = new double[][]{{0.0, 1.0, 1.0}, {0.0, 0.0, 0.0}, {1.0, 0.0, 1.0}, {1.0, 1.0, 0.0}};
        for (int i = 0; i < iterations; ++i) {
            for (double[] instance : instances) {
                Vector trainingInstance = new DenseVector(instance);
                ann.trainOnline(trainingInstance);
            }
        }
        assertClassifiesAll(ann, instances);

        File tmpModelFile = getTestTempFile(modelFilename);
        ann.setModelPath(tmpModelFile.getAbsolutePath());
        ann.writeModelToFile();
        MultilayerPerceptron annCopy = new MultilayerPerceptron(tmpModelFile.getAbsolutePath());
        assertClassifiesAll(annCopy, instances);
    }

    /**
     * Asserts that {@code ann} puts every instance on the same side of the 0.5 decision
     * threshold as its label. Each row holds the inputs followed by the label in its last column.
     * A NaN output fails the assertion.
     */
    private static void assertClassifiesAll(MultilayerPerceptron ann, double[][] instances) {
        for (double[] instance : instances) {
            Vector input = new DenseVector(instance).viewPart(0, instance.length - 1);
            double expected = instance[instance.length - 1];
            double actual = ann.getOutput(input).get(0);
            assertTrue("Misclassified " + Arrays.toString(instance) + ", output " + actual,
                (expected < 0.5 && actual < 0.5) || (expected >= 0.5 && actual >= 0.5));
        }
    }

    /**
     * Trains on the first 80% of the cancer dataset and requires more than 50% of the remaining
     * records to have an output within 0.1 of their label.
     */
    @Test
    public void testWithCancerDataSet() throws IOException {
        File cancerDataset = getTestTempFile("cancer.csv");
        writeLines(cancerDataset, Datasets.CANCER);
        List<Vector> records = Lists.newArrayList();
        for (String line : readRecordLines(cancerDataset)) {
            String[] tokens = CSVUtils.parseLine(line);
            double[] values = new double[tokens.length];
            for (int i = 0; i < tokens.length; ++i) {
                values[i] = Double.parseDouble(tokens[i]);
            }
            records.add(new DenseVector(values));
        }
        int splitPoint = (int) (records.size() * 0.8);
        List<Vector> trainingSet = records.subList(0, splitPoint);
        List<Vector> testSet = records.subList(splitPoint, records.size());

        // Every column except the last one (the label) is an input feature.
        int featureDimension = records.get(0).size() - 1;
        MultilayerPerceptron ann = new MultilayerPerceptron();
        ann.addLayer(featureDimension, false, "Sigmoid");
        ann.addLayer(featureDimension * 2, false, "Sigmoid");
        ann.addLayer(1, true, "Sigmoid");
        ann.setLearningRate(0.05).setMomentumWeight(0.5).setRegularizationWeight(0.001);
        int iteration = 2000;
        for (int i = 0; i < iteration; ++i) {
            for (Vector trainingInstance : trainingSet) {
                ann.trainOnline(trainingInstance);
            }
        }

        int correctInstances = 0;
        for (Vector testInstance : testSet) {
            Vector res = ann.getOutput(testInstance.viewPart(0, testInstance.size() - 1));
            double actual = res.get(0);
            double expected = testInstance.get(testInstance.size() - 1);
            if (Math.abs(actual - expected) <= 0.1) {
                ++correctInstances;
            }
        }
        double accuracy = (double) correctInstances / testSet.size() * 100.0;
        assertTrue("The classifier is even worse than a random guesser!", accuracy > 50.0);
        System.out.printf("Cancer DataSet. Classification precision: %d/%d = %f%%%n", correctInstances, testSet.size(), accuracy);
    }

    /**
     * Trains a three-output network on a seeded shuffle of the iris dataset (labels one-hot
     * encoded). It requires more than 50% of the held-out records to have all three outputs
     * within 0.1 of the encoded label.
     */
    @Test
    public void testWithIrisDataSet() throws IOException {
        File irisDataset = getTestTempFile("iris.csv");
        writeLines(irisDataset, Datasets.IRIS);
        int numOfClasses = 3;
        List<Vector> records = Lists.newArrayList();
        for (String line : readRecordLines(irisDataset)) {
            String[] tokens = CSVUtils.parseLine(line);
            // Numeric features followed by a one-hot encoding of the class label (last token).
            double[] values = new double[tokens.length - 1 + numOfClasses];
            for (int i = 0; i < tokens.length - 1; ++i) {
                values[i] = Double.parseDouble(tokens[i]);
            }
            String label = tokens[tokens.length - 1];
            if (label.equalsIgnoreCase("setosa")) {
                values[values.length - 3] = 1.0;
            } else if (label.equalsIgnoreCase("versicolor")) {
                values[values.length - 2] = 1.0;
            } else {
                values[values.length - 1] = 1.0;
            }
            records.add(new DenseVector(values));
        }
        // A fixed seed keeps the train/test split, and therefore the test outcome, reproducible.
        Collections.shuffle(records, new Random(IRIS_SHUFFLE_SEED));
        int splitPoint = (int) (records.size() * 0.8);
        List<Vector> trainingSet = records.subList(0, splitPoint);
        List<Vector> testSet = records.subList(splitPoint, records.size());

        int featureDimension = records.get(0).size() - numOfClasses;
        MultilayerPerceptron ann = new MultilayerPerceptron();
        ann.addLayer(featureDimension, false, "Sigmoid");
        ann.addLayer(featureDimension * 2, false, "Sigmoid");
        ann.addLayer(numOfClasses, true, "Sigmoid");
        ann.setLearningRate(0.05).setMomentumWeight(0.4).setRegularizationWeight(0.005);
        int iteration = 2000;
        for (int i = 0; i < iteration; ++i) {
            for (Vector trainingInstance : trainingSet) {
                ann.trainOnline(trainingInstance);
            }
        }

        int correctInstances = 0;
        for (Vector testInstance : testSet) {
            int labelOffset = testInstance.size() - numOfClasses;
            Vector res = ann.getOutput(testInstance.viewPart(0, labelOffset));
            boolean allCorrect = true;
            for (int i = 0; i < numOfClasses && allCorrect; ++i) {
                double expected = testInstance.get(labelOffset + i);
                // Phrased as "< 0.1" so that a NaN output counts as a miss, not a hit.
                allCorrect = Math.abs(expected - res.get(i)) < 0.1;
            }
            if (allCorrect) {
                ++correctInstances;
            }
        }
        double accuracy = (double) correctInstances / testSet.size() * 100.0;
        assertTrue("The model is even worse than a random guesser.", accuracy > 50.0);
        System.out.printf("Iris DataSet. Classification precision: %d/%d = %f%%%n", correctInstances, testSet.size(), accuracy);
    }

    /**
     * Reads {@code file} as UTF-8 and returns every line after the first (header) line.
     * Returns an empty list for an empty file.
     */
    private static List<String> readRecordLines(File file) throws IOException {
        List<String> lines = Files.readLines(file, Charsets.UTF_8);
        return lines.isEmpty() ? lines : lines.subList(1, lines.size());
    }
}

