/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier.sgd;

import com.carrotsearch.randomizedtesting.annotations.ThreadLeakLingering;
import java.util.Random;
import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression;
import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.PriorFunction;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.RandomWrapper;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.jet.random.Exponential;
import org.junit.Test;

/**
 * Tests for {@link AdaptiveLogisticRegression}: training a single wrapped learner and the full
 * adaptive ensemble on a synthetic logistic problem, copying a wrapped learner, and the step-size
 * schedule used between evolutionary updates.
 */
public final class AdaptiveLogisticRegressionTest
extends MahoutTestCase {
    /** Number of features in the synthetic training data (and in the true coefficient vector). */
    private static final int FEATURES = 200;

    @ThreadLeakLingering(linger=1000)
    @Test
    public void testTrain() {
        Random gen = RandomUtils.getRandom();
        Vector beta = randomBeta(gen);

        AdaptiveLogisticRegression.Wrapper cl = new AdaptiveLogisticRegression.Wrapper(2, FEATURES, new L1());
        cl.update(new double[]{1.0e-5, 1.0});
        for (int i = 0; i < 10000; ++i) {
            cl.train(getExample(i, gen, beta));
            if (i % 1000 == 0) {
                System.out.printf("%10d %10.3f\n", i, cl.getLearner().auc());
            }
        }
        assertEquals(1.0, cl.getLearner().auc(), 0.1);

        AdaptiveLogisticRegression adaptiveLogisticRegression = new AdaptiveLogisticRegression(2, FEATURES, new L1());
        // close() in finally so a failed assertion does not leave the learner's resources open
        try {
            adaptiveLogisticRegression.setInterval(1000);
            for (int i = 0; i < 20000; ++i) {
                AdaptiveLogisticRegression.TrainingExample r = getExample(i, gen, beta);
                adaptiveLogisticRegression.train(r.getKey(), r.getActual(), r.getInstance());
                if (i % 1000 == 0 && adaptiveLogisticRegression.getBest() != null) {
                    System.out.printf("%10d %10.4f %10.8f %.3f\n", i,
                            adaptiveLogisticRegression.auc(),
                            Math.log10(adaptiveLogisticRegression.getBest().getMappedParams()[0]),
                            adaptiveLogisticRegression.getBest().getMappedParams()[1]);
                }
            }
            assertEquals(1.0, adaptiveLogisticRegression.auc(), 0.1);
        } finally {
            adaptiveLogisticRegression.close();
        }
    }

    /**
     * Builds a random "true" coefficient vector of length {@link #FEATURES}. Each entry has a
     * magnitude drawn from {@code Exponential(0.5, gen)} and a sign that is negative with
     * probability 0.5.
     *
     * @param gen source of randomness; consumed in the same order as the original inline loops
     * @return the coefficient vector
     */
    private static Vector randomBeta(Random gen) {
        Exponential exp = new Exponential(0.5, gen);
        Vector beta = new DenseVector(FEATURES);
        for (Vector.Element element : beta.all()) {
            int sign = 1;
            if (gen.nextDouble() < 0.5) {
                sign = -1;
            }
            element.set(sign * exp.nextDouble());
        }
        return beta;
    }

    /**
     * Generates one synthetic training example.
     *
     * Each of the {@link #FEATURES} inputs is 1.0 with probability 0.3, otherwise 0.0. The target
     * is 1 with probability {@code 1 / (1 + exp(1.5 - x.beta))}, otherwise 0.
     *
     * @param i    key for the example
     * @param gen  source of randomness
     * @param beta true coefficient vector
     * @return the example, with a null group key
     */
    private static AdaptiveLogisticRegression.TrainingExample getExample(int i, Random gen, Vector beta) {
        Vector data = new DenseVector(FEATURES);
        for (Vector.Element element : data.all()) {
            element.set(gen.nextDouble() < 0.3 ? 1.0 : 0.0);
        }
        double p = 1.0 / (1.0 + Math.exp(1.5 - data.dot(beta)));
        int target = 0;
        if (gen.nextDouble() < p) {
            target = 1;
        }
        return new AdaptiveLogisticRegression.TrainingExample((long) i, null, target, data);
    }

    @Test
    public void copyLearnsAsExpected() {
        Random gen = RandomUtils.getRandom();
        Vector beta = randomBeta(gen);

        AdaptiveLogisticRegression.Wrapper w = new AdaptiveLogisticRegression.Wrapper(2, FEATURES, new L1());
        for (int i = 0; i < 3000; ++i) {
            w.train(getExample(i, gen, beta));
            if (i % 1000 == 0) {
                System.out.printf("%10d %.3f\n", i, w.getLearner().auc());
            }
        }
        System.out.printf("%10d %.3f\n", 3000, w.getLearner().auc());
        double auc1 = w.getLearner().auc();

        AdaptiveLogisticRegression.Wrapper w2 = w.copy();
        for (int i = 0; i < 5000; ++i) {
            if (i % 1000 == 0) {
                if (i == 0) {
                    // the copy keeps the model but not the AUC statistics
                    assertEquals("Should have started with no data", 0.5, w2.getLearner().auc(), 1.0e-4);
                } else if (i == 1000) {
                    double auc2 = w2.getLearner().auc();
                    assertTrue("Should have had head-start", Math.abs(auc2 - 0.5) > 0.1);
                    assertTrue("AUC should improve quickly on copy", auc1 < auc2);
                }
                System.out.printf("%10d %.3f\n", i, w2.getLearner().auc());
            }
            w2.train(getExample(i, gen, beta));
        }
        // a single exact-match check (the former 1e-5 check was subsumed by the exact one)
        assertEquals("Original should not change after copy is updated", auc1, w.getLearner().auc(), 0.0);
        assertTrue("AUC should improve significantly on copy", auc1 < w2.getLearner().auc() - 0.05);
    }

    @Test
    public void stepSize() {
        assertEquals(500L, AdaptiveLogisticRegression.stepSize(15000, 2.0));
        assertEquals(2000L, AdaptiveLogisticRegression.stepSize(15000, 2.6));
        assertEquals(5000L, AdaptiveLogisticRegression.stepSize(24000, 2.6));
        assertEquals(10000L, AdaptiveLogisticRegression.stepSize(15000, 3.0));
    }

    @Test
    @ThreadLeakLingering(linger=1000)
    public void constantStep() {
        AdaptiveLogisticRegression lr = new AdaptiveLogisticRegression(2, 1000, new L1());
        try {
            lr.setInterval(5000);
            assertEquals(20000L, lr.nextStep(15000));
            assertEquals(20000L, lr.nextStep(15001));
            assertEquals(20000L, lr.nextStep(16500));
            assertEquals(20000L, lr.nextStep(19999));
        } finally {
            lr.close();
        }
    }

    @Test
    @ThreadLeakLingering(linger=1000)
    public void growingStep() {
        AdaptiveLogisticRegression lr = new AdaptiveLogisticRegression(2, 1000, new L1());
        try {
            lr.setInterval(2000, 10000);
            for (int i = 2000; i < 20000; i += 2000) {
                assertEquals(i + 2000L, lr.nextStep(i));
            }
            for (int i = 20000; i < 50000; i += 5000) {
                assertEquals(i + 5000L, lr.nextStep(i));
            }
            for (int i = 50000; i < 500000; i += 10000) {
                assertEquals(i + 10000L, lr.nextStep(i));
            }
        } finally {
            lr.close();
        }
    }
}

