/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier.sgd;

import com.carrotsearch.randomizedtesting.annotations.ThreadLeakLingering;
import com.google.common.io.Closeables;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.Closeable;
import java.io.DataInput;
import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Random;
import org.apache.hadoop.io.Writable;
import org.apache.mahout.classifier.OnlineLearner;
import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression;
import org.apache.mahout.classifier.sgd.CrossFoldLearner;
import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.classifier.sgd.PolymorphicWritable;
import org.apache.mahout.classifier.sgd.PriorFunction;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.RandomWrapper;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.DoubleFunction;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.stats.GlobalOnlineAuc;
import org.apache.mahout.math.stats.OnlineAuc;
import org.junit.Test;

/**
 * Round-trip serialization tests for {@link OnlineAuc} and the SGD learners
 * ({@link OnlineLogisticRegression}, {@link CrossFoldLearner}, {@link AdaptiveLogisticRegression}).
 *
 * <p>Each test builds some state, serializes it through {@link PolymorphicWritable}, reads it
 * back, checks that the copy matches the original, and then keeps training both to check that the
 * copy continues to behave like the original.
 */
public final class ModelSerializerTest extends MahoutTestCase {

    /**
     * Serializes {@code m} into an in-memory buffer with {@link PolymorphicWritable#write} and
     * immediately deserializes it with {@link PolymorphicWritable#read}.
     *
     * @param m     the object to round-trip
     * @param clazz the type the deserialized copy is read as
     * @param <T>   the writable type
     * @return a freshly deserialized copy of {@code m}
     * @throws IOException if writing, closing, or reading fails
     */
    private static <T extends Writable> T roundTrip(T m, Class<T> clazz) throws IOException {
        ByteArrayOutputStream buf = new ByteArrayOutputStream(1000);
        DataOutputStream dos = new DataOutputStream(buf);
        try {
            PolymorphicWritable.write(dos, m);
        } finally {
            // swallowIOException = false: a failure while closing is reported, not hidden.
            Closeables.close(dos, false);
        }
        DataInputStream in = new DataInputStream(new ByteArrayInputStream(buf.toByteArray()));
        return PolymorphicWritable.read(in, clazz);
    }

    /**
     * Returns the largest absolute element-wise difference between the coefficient matrices of two
     * models.
     *
     * <p>The absolute value matters: taking the plain maximum of the signed difference would let any
     * deviation where {@code b} is larger than {@code a} go unnoticed.
     */
    private static double maxAbsBetaDifference(OnlineLogisticRegression a, OnlineLogisticRegression b) {
        return a.getBeta().minus(b.getBeta()).aggregate(Functions.MAX, Functions.ABS);
    }

    @Test
    public void onlineAucRoundtrip() throws IOException {
        RandomUtils.useTestSeed();
        OnlineAuc auc1 = new GlobalOnlineAuc();
        Random gen = RandomUtils.getRandom();
        // Negatives ~ N(0, 1), positives ~ N(1, 1).
        for (int i = 0; i < 10000; ++i) {
            auc1.addSample(0, gen.nextGaussian());
            auc1.addSample(1, gen.nextGaussian() + 1.0);
        }
        assertEquals(0.76, auc1.auc(), 0.01);

        OnlineAuc auc3 = roundTrip(auc1, OnlineAuc.class);
        // Directly after deserialization the copy must report exactly the same value.
        assertEquals(auc1.auc(), auc3.auc(), 0.0);

        // Feed both instances further (independent) samples from the same distributions;
        // their estimates should stay close.
        for (int i = 0; i < 1000; ++i) {
            auc1.addSample(0, gen.nextGaussian());
            auc1.addSample(1, gen.nextGaussian() + 1.0);
            auc3.addSample(0, gen.nextGaussian());
            auc3.addSample(1, gen.nextGaussian() + 1.0);
        }
        assertEquals(auc1.auc(), auc3.auc(), 0.01);
    }

    @Test
    public void onlineLogisticRegressionRoundTrip() throws IOException {
        OnlineLogisticRegression olr = new OnlineLogisticRegression(2, 5, new L1());
        train(olr, 100);
        OnlineLogisticRegression olr3 = roundTrip(olr, OnlineLogisticRegression.class);
        assertEquals(0.0, maxAbsBetaDifference(olr, olr3), 1.0e-6);

        // Continued training of the original and the copy must keep their coefficients aligned.
        train(olr, 100);
        train(olr3, 100);
        assertEquals(0.0, maxAbsBetaDifference(olr, olr3), 1.0e-6);

        olr.close();
        olr3.close();
    }

    @Test
    public void crossFoldLearnerRoundTrip() throws IOException {
        CrossFoldLearner learner = new CrossFoldLearner(5, 2, 5, new L1());
        train(learner, 100);
        CrossFoldLearner olr3 = roundTrip(learner, CrossFoldLearner.class);
        double auc1 = learner.auc();
        assertTrue(auc1 > 0.85);
        assertEquals(auc1, learner.auc(), 1.0e-6);
        assertEquals(auc1, olr3.auc(), 1.0e-6);

        train(learner, 100);
        train(learner, 100);
        train(olr3, 100);

        // A former assertEquals(learner.auc(), learner.auc(), 0.02) compared a value with itself
        // and could never fail; it was removed.
        assertEquals(learner.auc(), olr3.auc(), 0.02);
        double auc2 = learner.auc();
        assertTrue(auc2 > auc1);

        learner.close();
        olr3.close();
    }

    @ThreadLeakLingering(linger = 1000)
    @Test
    public void adaptiveLogisticRegressionRoundTrip() throws IOException {
        AdaptiveLogisticRegression learner = new AdaptiveLogisticRegression(2, 5, new L1());
        learner.setInterval(200);
        train(learner, 400);
        AdaptiveLogisticRegression olr3 = roundTrip(learner, AdaptiveLogisticRegression.class);
        double auc1 = learner.auc();
        assertTrue(auc1 > 0.85);
        assertEquals(auc1, learner.auc(), 1.0e-6);
        assertEquals(auc1, olr3.auc(), 1.0e-6);

        train(learner, 1000);
        train(learner, 1000);
        train(olr3, 1000);

        // A former assertEquals(learner.auc(), learner.auc(), 0.005) compared a value with itself
        // and could never fail; it was removed.
        assertEquals(learner.auc(), olr3.auc(), 0.005);
        double auc2 = learner.auc();
        assertTrue(String.format("%.3f > %.3f", auc2, auc1), auc2 > auc1);

        learner.close();
        olr3.close();
    }

    /**
     * Trains {@code olr} on {@code n} synthetic examples.
     *
     * <p>Each example is a vector of independent standard-normal features. Its label is 1 when a
     * uniform draw in [0, 1) falls below {@code beta . x}, so P(label = 1) is {@code beta . x}
     * clamped to [0, 1]. A new generator is obtained from {@link RandomUtils#getRandom()} on
     * every call.
     *
     * @param olr the learner to train
     * @param n   number of examples to generate
     */
    private static void train(OnlineLearner olr, int n) {
        Vector beta = new DenseVector(new double[]{1.0, -1.0, 0.0, 0.5, -0.5});
        Random gen = RandomUtils.getRandom();
        for (int i = 0; i < n; ++i) {
            Vector x = randomVector(gen, beta.size());
            int target = gen.nextDouble() < beta.dot(x) ? 1 : 0;
            olr.train(target, x);
        }
    }

    /**
     * Returns a dense vector of length {@code n} filled with standard-normal draws from
     * {@code gen}.
     */
    private static Vector randomVector(final Random gen, int n) {
        DenseVector x = new DenseVector(n);
        x.assign(new DoubleFunction() {
            @Override
            public double apply(double v) {
                // The current element value is ignored; every slot gets a fresh Gaussian draw.
                return gen.nextGaussian();
            }
        });
        return x;
    }
}

