/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier.sgd;

import com.google.common.base.Charsets;
import com.google.common.base.Splitter;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.io.Resources;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInput;
import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.DataOutputStream;
import java.io.IOException;
import java.lang.reflect.Field;
import java.net.URL;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import org.apache.hadoop.io.Writable;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.classifier.OnlineLearner;
import org.apache.mahout.classifier.sgd.CrossFoldLearner;
import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.L2;
import org.apache.mahout.classifier.sgd.OnlineBaseTest;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.classifier.sgd.PolymorphicWritable;
import org.apache.mahout.classifier.sgd.PriorFunction;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.RandomWrapper;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.vectorizer.encoders.Dictionary;
import org.junit.Assert;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Tests for {@link OnlineLogisticRegression} and {@link CrossFoldLearner}: training on the
 * standard data set, cross-validated AUC, hand-set coefficients in {@code classify}/
 * {@code classifyFull}, accuracy on the iris data, and Writable serialization round trips.
 */
public final class OnlineLogisticRegressionTest
extends OnlineBaseTest {
    private static final Logger logger = LoggerFactory.getLogger(OnlineLogisticRegressionTest.class);

    /**
     * Trains a {@link CrossFoldLearner} with an L1 prior on the standard data and checks its
     * error against the thresholds enforced by {@code OnlineBaseTest.test}.
     */
    @Test
    public void crossValidation() throws IOException {
        Vector target = readStandardData();
        // NOTE(review): constructor args presumably are (folds, categories, features, prior) — confirm.
        CrossFoldLearner lr = new CrossFoldLearner(5, 2, 8, new L1()).lambda(0.001).learningRate(50.0);
        train(getInput(), target, lr);
        System.out.printf("%.2f %.5f%n", lr.auc(), lr.logLikelihood());
        test(getInput(), target, lr, 0.05, 0.3);
    }

    /**
     * Runs 100 epochs over {@code cancer.csv} in a fixed random order. It requires the
     * cross-validated AUC to stay within 0.2 of 1.0 after each epoch and within 0.1 at the end.
     */
    @Test
    public void crossValidatedAuc() throws IOException {
        RandomUtils.useTestSeed();
        Random gen = RandomUtils.getRandom();
        Matrix data = readCsv("cancer.csv");
        CrossFoldLearner lr = new CrossFoldLearner(5, 2, 10, new L1())
            .stepOffset(10)
            .decayExponent(0.7)
            .lambda(0.001)
            .learningRate(5.0);
        int k = 0;
        int[] ordering = permute(gen, data.numRows());
        for (int epoch = 0; epoch < 100; ++epoch) {
            for (int row : ordering) {
                // Column 9 is used as the target label. NOTE(review): the feature vector is the
                // whole row, so it also contains column 9 — confirm this is intended.
                lr.train((long) row, (int) data.get(row, 9), data.viewRow(row));
                // CSV-style learning curve: epoch, cumulative step, current AUC.
                System.out.printf("%d,%d,%.3f%n", epoch, k++, lr.auc());
            }
            Assert.assertEquals(1.0, lr.auc(), 0.2);
        }
        Assert.assertEquals(1.0, lr.auc(), 0.1);
    }

    /**
     * Sets coefficients by hand on a 3-category, 2-feature model and checks the probabilities
     * from {@code classify} and {@code classifyFull} against closed-form softmax values.
     *
     * <p>As the expectations below encode, category 0 is the reference category with score 0.
     * Beta row {@code j} holds the weights for category {@code j + 1}. {@code classify} returns
     * only categories 1..2, while {@code classifyFull} returns all three.
     */
    @Test
    public void testClassify() {
        OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 2, new L2(1.0));
        lr.setBeta(0, 0, -1.0);
        lr.setBeta(1, 0, -2.0);

        double third = 1.0 / 3.0;

        // Feature 0 is zero and feature 1 has no weight yet, so every category scores 0.
        Vector v = lr.classify(vector(0.0, 0.0));
        Assert.assertEquals(third, v.get(0), 1.0E-8);
        Assert.assertEquals(third, v.get(1), 1.0E-8);
        v = lr.classifyFull(vector(0.0, 0.0));
        Assert.assertEquals(1.0, v.zSum(), 1.0E-8);
        Assert.assertEquals(third, v.get(0), 1.0E-8);
        Assert.assertEquals(third, v.get(1), 1.0E-8);
        Assert.assertEquals(third, v.get(2), 1.0E-8);

        v = lr.classify(vector(0.0, 1.0));
        Assert.assertEquals(third, v.get(0), 0.001);
        Assert.assertEquals(third, v.get(1), 0.001);
        v = lr.classifyFull(vector(0.0, 1.0));
        Assert.assertEquals(1.0, v.zSum(), 1.0E-8);
        Assert.assertEquals(third, v.get(0), 0.001);
        Assert.assertEquals(third, v.get(1), 0.001);
        Assert.assertEquals(third, v.get(2), 0.001);

        // Scores are 0, -1, -2 for categories 0, 1, 2.
        double z = 1.0 + Math.exp(-1.0) + Math.exp(-2.0);
        v = lr.classify(vector(1.0, 0.0));
        Assert.assertEquals(Math.exp(-1.0) / z, v.get(0), 1.0E-8);
        Assert.assertEquals(Math.exp(-2.0) / z, v.get(1), 1.0E-8);
        v = lr.classifyFull(vector(1.0, 0.0));
        Assert.assertEquals(1.0, v.zSum(), 1.0E-8);
        Assert.assertEquals(1.0 / z, v.get(0), 1.0E-8);
        Assert.assertEquals(Math.exp(-1.0) / z, v.get(1), 1.0E-8);
        Assert.assertEquals(Math.exp(-2.0) / z, v.get(2), 1.0E-8);

        // With beta row 0 = (-1, 1), input (1, 1) gives scores 0, 0, -2.
        lr.setBeta(0, 1, 1.0);
        double z2 = 1.0 + Math.exp(0.0) + Math.exp(-2.0);
        v = lr.classifyFull(vector(1.0, 1.0));
        Assert.assertEquals(1.0, v.zSum(), 1.0E-8);
        Assert.assertEquals(Math.exp(0.0) / z2, v.get(1), 0.001);
        Assert.assertEquals(Math.exp(-2.0) / z2, v.get(2), 0.001);
        Assert.assertEquals(1.0 / z2, v.get(0), 0.001);

        // With beta row 1 = (-2, 3), input (1, 1) gives scores 0, 0, 1.
        lr.setBeta(1, 1, 3.0);
        double z3 = 1.0 + Math.exp(0.0) + Math.exp(1.0);
        v = lr.classifyFull(vector(1.0, 1.0));
        Assert.assertEquals(1.0, v.zSum(), 1.0E-8);
        Assert.assertEquals(Math.exp(0.0) / z3, v.get(1), 1.0E-8);
        Assert.assertEquals(Math.exp(1.0) / z3, v.get(2), 1.0E-8);
        Assert.assertEquals(1.0 / z3, v.get(0), 1.0E-8);
    }

    /**
     * Trains 200 independent models on a fixed 100/50 random split of {@code iris.csv}.
     *
     * <p>It builds a histogram of how many of the 50 held-out examples each model gets right.
     * No run may score below {@code floor(0.95 * test.size())} correct. No run may score a
     * perfect 100%, which would suggest something is too easy or leaking.
     */
    @Test
    public void iris() throws IOException {
        RandomUtils.useTestSeed();
        Splitter onComma = Splitter.on(",");
        List<String> raw = Resources.readLines(Resources.getResource("iris.csv"), Charsets.UTF_8);

        List<Vector> data = Lists.newArrayList();
        List<Integer> target = Lists.newArrayList();
        Dictionary dict = new Dictionary();
        List<Integer> order = Lists.newArrayList();
        // Skip the header line. Each row holds 4 numeric features followed by the class name.
        for (String line : raw.subList(1, raw.size())) {
            order.add(order.size());
            Vector v = new DenseVector(5);
            v.set(0, 1.0);  // constant intercept feature
            int i = 1;
            Iterable<String> values = onComma.split(line);
            for (String value : Iterables.limit(values, 4)) {
                v.set(i++, Double.parseDouble(value));
            }
            data.add(v);
            target.add(dict.intern(Iterables.get(values, 4)));
        }

        Random random = RandomUtils.getRandom();
        Collections.shuffle(order, random);
        // Disjoint views of the shuffled order. Shuffling `train` later never disturbs `test`.
        List<Integer> train = order.subList(0, 100);
        List<Integer> test = order.subList(100, 150);
        logger.warn("Training set = {}", train);
        logger.warn("Test set = {}", test);

        // correct[c] counts the runs that classified exactly c held-out examples correctly.
        int[] correct = new int[test.size() + 1];
        for (int run = 0; run < 200; ++run) {
            OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 5, new L2(1.0));
            for (int pass = 0; pass < 30; ++pass) {
                Collections.shuffle(train, random);
                for (int k : train) {
                    int actual = target.get(k);
                    lr.train(actual, data.get(k));
                }
            }
            int hits = 0;
            for (int k : test) {
                int predicted = lr.classifyFull(data.get(k)).maxValueIndex();
                if (predicted == target.get(k)) {
                    hits++;
                }
            }
            correct[hits]++;
        }

        for (int i = 0; i < Math.floor(0.95 * test.size()); ++i) {
            Assert.assertEquals(String.format("%d trials had unacceptable accuracy of only %.0f%%: ", correct[i], 100.0 * i / test.size()), 0L, correct[i]);
        }
        // Report the same bucket that is asserted: the runs with every test example correct.
        Assert.assertEquals(String.format("%d trials had unrealistic accuracy of 100%%", correct[test.size()]), 0L, correct[test.size()]);
    }

    /**
     * Trains a 2-category, 8-feature {@link OnlineLogisticRegression} with an L1 prior on the
     * standard data. It checks the error against the thresholds enforced by
     * {@code OnlineBaseTest.test}.
     */
    @Test
    public void testTrain() throws Exception {
        Vector target = readStandardData();
        OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 8, new L1()).lambda(0.001).learningRate(50.0);
        train(getInput(), target, lr);
        test(getInput(), target, lr, 0.05, 0.3);
    }

    /**
     * Round-trips a configured model through {@link PolymorphicWritable}. It verifies that the
     * deserialized copy carries the same hyper-parameters as the original.
     */
    @Test
    public void testSerializationAndDeSerialization() throws Exception {
        OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 8, new L1())
            .lambda(0.001)
            .stepOffset(11)
            .alpha(0.01)
            .learningRate(50.0)
            .decayExponent(-0.02);
        lr.close();

        byte[] output;
        try (ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
             DataOutputStream dataOutputStream = new DataOutputStream(byteArrayOutputStream)) {
            PolymorphicWritable.write(dataOutputStream, lr);
            // Make sure every written byte has reached the array before taking the snapshot.
            dataOutputStream.flush();
            output = byteArrayOutputStream.toByteArray();
        }

        OnlineLogisticRegression read;
        try (ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(output);
             DataInputStream dataInputStream = new DataInputStream(byteArrayInputStream)) {
            read = (OnlineLogisticRegression) PolymorphicWritable.read(dataInputStream, OnlineLogisticRegression.class);
        }

        Assert.assertEquals(0.001, read.getLambda(), 1.0E-7);

        // Check the configured values on the original, then check that the deserialized copy
        // carries the same values. The latter is what actually exercises the round trip.
        for (OnlineLogisticRegression model : new OnlineLogisticRegression[] {lr, read}) {
            int stepOffsetVal = (Integer) privateField(model, "stepOffset");
            Assert.assertEquals(11L, stepOffsetVal);
            double decayFactorVal = (Double) privateField(model, "decayFactor");
            Assert.assertEquals(0.01, decayFactorVal, 1.0E-7);
            double mu0Val = (Double) privateField(model, "mu0");
            Assert.assertEquals(50.0, mu0Val, 1.0E-7);
            double forgettingExponentVal = (Double) privateField(model, "forgettingExponent");
            Assert.assertEquals(-0.02, forgettingExponentVal, 1.0E-7);
        }
    }

    /** Builds a dense vector from the given components. */
    private static Vector vector(double... values) {
        return new DenseVector(values);
    }

    /**
     * Reads a non-public field declared on {@link OnlineLogisticRegression} via reflection.
     *
     * @param model the instance to read from
     * @param name the declared field name
     * @return the (boxed) field value
     * @throws NoSuchFieldException if no such field is declared on the class
     * @throws IllegalAccessException if the field cannot be made accessible
     */
    private static Object privateField(OnlineLogisticRegression model, String name)
            throws NoSuchFieldException, IllegalAccessException {
        Field field = OnlineLogisticRegression.class.getDeclaredField(name);
        field.setAccessible(true);
        return field.get(model);
    }
}

