/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier.sgd;

import com.google.common.base.CharMatcher;
import com.google.common.base.Charsets;
import com.google.common.base.Splitter;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.io.CharStreams;
import com.google.common.io.Resources;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.classifier.OnlineLearner;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.RandomWrapper;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.Functions;

public abstract class OnlineBaseTest
extends MahoutTestCase {
    /** Matrix loaded by {@link #readStandardData()}; null until that method has been called. */
    private Matrix input;

    /**
     * Returns the matrix most recently loaded by {@link #readStandardData()}.
     *
     * @return the loaded input matrix, or null if no data has been read yet
     */
    Matrix getInput() {
        return this.input;
    }

    /**
     * Loads the {@code sgd.csv} resource into {@link #getInput()} and builds the matching target vector.
     *
     * <p>The target has 60 entries: indices 0-29 are 0.0 and indices 30-59 are 1.0.
     * NOTE(review): this presumes sgd.csv contains 60 rows with the class-0 rows first; confirm against the resource.
     *
     * @return the 0/1 target vector for the standard data set
     * @throws IOException if the resource cannot be read
     */
    Vector readStandardData() throws IOException {
        this.input = readCsv("sgd.csv");
        Vector target = new DenseVector(60);
        target.assign(0.0);
        target.viewPart(30, 30).assign(1.0);
        return target;
    }

    /**
     * Trains {@code lr} with a single pass over every row of {@code input}, visiting the rows in a
     * reproducible random order, and then closes the learner.
     *
     * @param input  one training example per row
     * @param target label for each row; {@code target.get(row)} is truncated to an int class id
     * @param lr     learner to train; it is closed when the pass completes
     */
    static void train(Matrix input, Vector target, OnlineLearner lr) {
        // Fixed seed so every run sees the same example order.
        RandomUtils.useTestSeed();
        Random gen = RandomUtils.getRandom();
        for (int row : permute(gen, input.numRows())) {
            lr.train((int) target.get(row), input.viewRow(row));
        }
        lr.close();
    }

    /**
     * Classifies {@code input} and asserts that the scores are close to {@code target}.
     *
     * <p>Column 0 of {@code lr.classify(input)} is compared with {@code target}. Its mean absolute error
     * must be within {@code expected_mean_error} of 0 and its maximum absolute error within
     * {@code expected_absolute_error} of 0. The method also checks that {@code classifyScalar} matches
     * that column (L1 distance within 1e-5) and that column 1 of {@code classifyFull} matches it
     * (L1 distance within 1e-4).
     *
     * @param input                   rows to classify
     * @param target                  expected score for each row
     * @param lr                      trained classifier under test
     * @param expected_mean_error     tolerance for the mean absolute error
     * @param expected_absolute_error tolerance for the maximum absolute error
     */
    static void test(Matrix input, Vector target, AbstractVectorClassifier lr, double expected_mean_error, double expected_absolute_error) {
        Matrix tmp = lr.classify(input);
        Vector scores = tmp.viewColumn(0);
        Vector residual = scores.minus(target);
        // Average over the actual number of examples rather than a hard-coded count.
        double meanAbsoluteError = residual.aggregate(Functions.PLUS, Functions.ABS) / target.size();
        double maxAbsoluteError = residual.aggregate(Functions.MAX, Functions.ABS);
        System.out.printf("mAE = %.4f, maxAE = %.4f%n", meanAbsoluteError, maxAbsoluteError);
        assertEquals(0.0, meanAbsoluteError, expected_mean_error);
        assertEquals(0.0, maxAbsoluteError, expected_absolute_error);

        // The alternative classification entry points must agree with classify().
        Vector v = lr.classifyScalar(input);
        assertEquals(0.0, v.minus(scores).norm(1.0), 1.0E-5);
        v = lr.classifyFull(input).viewColumn(1);
        assertEquals(0.0, v.minus(scores).norm(1.0), 1.0E-4);
    }

    /**
     * Returns a random permutation of {@code 0 .. max-1}, built with the "inside-out" Fisher-Yates shuffle.
     *
     * @param gen source of randomness; exactly {@code max - 1} calls to {@code nextInt} are made
     * @param max number of elements; 0 yields an empty array
     * @return the permutation
     */
    static int[] permute(Random gen, int max) {
        int[] permutation = new int[max];
        for (int i = 1; i < max; ++i) {
            int j = gen.nextInt(i + 1);
            // Move the element at j to slot i, then place i at j. When j == i this simply stores i.
            permutation[i] = permutation[j];
            permutation[j] = i;
        }
        return permutation;
    }

    /**
     * Reads a comma-separated classpath resource into a dense matrix.
     *
     * <p>The first line is a header. Its fields, with spaces and double quotes trimmed, become the column
     * labels in header order. Each following line becomes one matrix row, with fields parsed by
     * {@link Double#parseDouble(String)}.
     *
     * @param resourceName classpath resource to read (UTF-8)
     * @return matrix with one row per data line and one column per header field
     * @throws IOException if the resource cannot be read or contains no header line
     */
    static Matrix readCsv(String resourceName) throws IOException {
        Splitter onCommas = Splitter.on(',').trimResults(CharMatcher.anyOf(" \""));

        List<String> lines;
        InputStreamReader isr = new InputStreamReader(Resources.getResource(resourceName).openStream(), Charsets.UTF_8);
        try {
            lines = CharStreams.readLines(isr);
        } finally {
            // Always release the underlying resource stream.
            isr.close();
        }
        if (lines.isEmpty()) {
            throw new IOException("CSV resource " + resourceName + " is empty; expected a header line");
        }

        List<String> headers = Lists.newArrayList(onCommas.split(lines.get(0)));
        List<String> dataLines = lines.subList(1, lines.size());
        DenseMatrix result = new DenseMatrix(dataLines.size(), headers.size());

        Map<String, Integer> labels = Maps.newHashMap();
        for (int column = 0; column < headers.size(); ++column) {
            labels.put(headers.get(column), column);
        }
        result.setColumnLabelBindings(labels);

        int row = 0;
        for (String line : dataLines) {
            int column = 0;
            for (String field : onCommas.split(line)) {
                result.set(row, column, Double.parseDouble(field));
                ++column;
            }
            ++row;
        }
        return result;
    }
}

