package org.apache.mahout.classifier.sgd;

import com.google.common.base.Charsets;
import com.google.common.base.Splitter;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.io.Closer;
import com.google.common.io.Resources;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.RandomWrapper;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.vectorizer.encoders.Dictionary;
import org.junit.Assert;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.class */
public final class OnlineLogisticRegressionTest extends OnlineBaseTest {
    // Class logger; in this file it is only used by iris() to record the train/test row split.
    private static final Logger logger = LoggerFactory.getLogger(OnlineLogisticRegressionTest.class);

    /**
     * Trains a {@link CrossFoldLearner} (L1 prior, lambda 0.001, learning rate 50) on the
     * standard data, prints its AUC and log-likelihood to stdout, and then runs the shared
     * {@code test} check with the thresholds 0.05 and 0.3.
     *
     * <p>{@code readStandardData}, {@code getInput}, {@code train} and {@code test} are not
     * defined in this class; presumably they are inherited from {@code OnlineBaseTest}.
     *
     * @throws IOException if the helpers fail to load their data
     */
    @Test
    public void crossValidation() throws IOException {
        Vector target = readStandardData();

        CrossFoldLearner learner = new CrossFoldLearner(5, 2, 8, new L1())
                .lambda(0.001d)
                .learningRate(50.0d);

        train(getInput(), target, learner);

        // Report the cross-validated quality figures before asserting on the model.
        double auc = learner.auc();
        double logLikelihood = learner.logLikelihood();
        System.out.printf("%.2f %.5f\n", auc, logLikelihood);

        test(getInput(), target, learner, 0.05d, 0.3d);
    }

    /**
     * Feeds the rows of {@code cancer.csv} to a {@link CrossFoldLearner} for 100 passes in a
     * fixed pseudo-random order and checks that the learner's AUC approaches 1.
     *
     * <p>Column 9 of each row is used as the integer label, and the whole row is passed as the
     * feature vector. After every pass the AUC must be within 0.2 of 1.0; after the last pass
     * it must be within 0.1. Every training step also prints {@code pass,step,auc} to stdout.
     *
     * @throws IOException if the CSV resource cannot be read
     */
    @Test
    public void crossValidatedAuc() throws IOException {
        RandomUtils.useTestSeed();
        RandomWrapper gen = RandomUtils.getRandom();

        Matrix data = readCsv("cancer.csv");
        CrossFoldLearner learner = new CrossFoldLearner(5, 2, 10, new L1())
                .stepOffset(10)
                .decayExponent(0.7d)
                .lambda(0.001d)
                .learningRate(5.0d);

        // The same permutation of row indices is replayed on every pass.
        int[] order = permute(gen, data.numRows());
        int step = 0;
        for (int pass = 0; pass < 100; pass++) {
            for (int pos = 0; pos < order.length; pos++) {
                int row = order[pos];
                int label = (int) data.get(row, 9);
                learner.train(row, label, data.viewRow(row));
                System.out.printf("%d,%d,%.3f\n", pass, step, learner.auc());
                step++;
            }
            assertEquals(1.0d, learner.auc(), 0.2d);
        }
        assertEquals(1.0d, learner.auc(), 0.1d);
    }

    /**
     * Checks {@code classify} and {@code classifyFull} on a model whose coefficients are set by
     * hand, comparing the returned scores with values computed directly from {@link Math#exp}.
     *
     * <p>For every input, the {@code classifyFull} result must sum to 1. The expected values
     * below show that {@code classify} returns the entries that {@code classifyFull} reports at
     * indices 1 and 2.
     */
    @Test
    public void testClassify() {
        final double tight = 1.0E-8d;
        final double loose = 0.001d;
        final double third = 0.3333333333333333d;

        OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 2, new L2(1.0d));
        lr.setBeta(0, 0, -1.0d);
        lr.setBeta(1, 0, -2.0d);

        Vector partial;
        Vector full;

        // Input (0, 0): every score is expected to be 1/3.
        partial = lr.classify(new DenseVector(new double[]{0.0d, 0.0d}));
        assertEquals(third, partial.get(0), tight);
        assertEquals(third, partial.get(1), tight);
        full = lr.classifyFull(new DenseVector(new double[]{0.0d, 0.0d}));
        assertEquals(1.0d, full.zSum(), tight);
        assertEquals(third, full.get(0), tight);
        assertEquals(third, full.get(1), tight);
        assertEquals(third, full.get(2), tight);

        // Input (0, 1): still 1/3 each, checked with the looser tolerance.
        partial = lr.classify(new DenseVector(new double[]{0.0d, 1.0d}));
        assertEquals(third, partial.get(0), loose);
        assertEquals(third, partial.get(1), loose);
        full = lr.classifyFull(new DenseVector(new double[]{0.0d, 1.0d}));
        assertEquals(1.0d, full.zSum(), tight);
        assertEquals(third, full.get(0), loose);
        assertEquals(third, full.get(1), loose);
        assertEquals(third, full.get(2), loose);

        // Input (1, 0): scores proportional to 1, exp(-1), exp(-2).
        double expNeg1 = Math.exp(-1.0d);
        double expNeg2 = Math.exp(-2.0d);
        double norm = (1.0d + expNeg1) + expNeg2;
        partial = lr.classify(new DenseVector(new double[]{1.0d, 0.0d}));
        assertEquals(expNeg1 / norm, partial.get(0), tight);
        assertEquals(expNeg2 / norm, partial.get(1), tight);
        full = lr.classifyFull(new DenseVector(new double[]{1.0d, 0.0d}));
        assertEquals(1.0d, full.zSum(), tight);
        assertEquals(1.0d / norm, full.get(0), tight);
        assertEquals(expNeg1 / norm, full.get(1), tight);
        assertEquals(expNeg2 / norm, full.get(2), tight);

        // After setBeta(0, 1, 1), input (1, 1): scores proportional to 1, exp(0), exp(-2).
        lr.setBeta(0, 1, 1.0d);
        double exp0 = Math.exp(0.0d);
        double normAfterFirstUpdate = (1.0d + exp0) + expNeg2;
        full = lr.classifyFull(new DenseVector(new double[]{1.0d, 1.0d}));
        assertEquals(1.0d, full.zSum(), tight);
        assertEquals(exp0 / normAfterFirstUpdate, full.get(1), loose);
        assertEquals(expNeg2 / normAfterFirstUpdate, full.get(2), loose);
        assertEquals(1.0d / normAfterFirstUpdate, full.get(0), loose);

        // After setBeta(1, 1, 3), input (1, 1): scores proportional to 1, exp(0), exp(1).
        lr.setBeta(1, 1, 3.0d);
        double exp1 = Math.exp(1.0d);
        double normAfterSecondUpdate = (1.0d + exp0) + exp1;
        full = lr.classifyFull(new DenseVector(new double[]{1.0d, 1.0d}));
        assertEquals(1.0d, full.zSum(), tight);
        assertEquals(exp0 / normAfterSecondUpdate, full.get(1), tight);
        assertEquals(exp1 / normAfterSecondUpdate, full.get(2), tight);
        assertEquals(1.0d / normAfterSecondUpdate, full.get(0), tight);
    }

    /**
     * Trains 200 independent three-category models on a fixed 100-row subset of
     * {@code iris.csv} and checks the accuracy of each one on 50 held-out rows.
     *
     * <p>Each data line is turned into a 5-element vector: a constant 1.0 followed by the first
     * four comma-separated values. The fifth value is the class label, interned to an integer
     * through a {@link Dictionary}. The first line of the file is skipped.
     *
     * <p>The test fails if any trial gets fewer than {@code floor(0.95 * testSize)} held-out
     * rows right, or if any trial gets every held-out row right.
     *
     * @throws IOException if the CSV resource cannot be read
     */
    @Test
    public void iris() throws IOException {
        RandomUtils.useTestSeed();
        Splitter onComma = Splitter.on(",");
        List<String> raw = Resources.readLines(Resources.getResource("iris.csv"), Charsets.UTF_8);

        List<Vector> data = Lists.newArrayList();
        List<Integer> target = Lists.newArrayList();
        Dictionary labels = new Dictionary();
        List<Integer> order = Lists.newArrayList();

        for (String line : raw.subList(1, raw.size())) {
            // Row index of the vector about to be appended to data.
            order.add(order.size());

            Vector features = new DenseVector(5);
            features.set(0, 1.0d);
            int slot = 1;
            Iterable<String> values = onComma.split(line);
            for (String value : Iterables.limit(values, 4)) {
                features.set(slot++, Double.parseDouble(value));
            }
            data.add(features);

            target.add(labels.intern(Iterables.get(values, 4)));
        }

        // Fixed split: first 100 shuffled rows for training, next 50 for testing.
        RandomWrapper random = RandomUtils.getRandom();
        Collections.shuffle(order, random);
        List<Integer> trainRows = order.subList(0, 100);
        List<Integer> testRows = order.subList(100, 150);
        logger.warn("Training set = {}", trainRows);
        logger.warn("Test set = {}", testRows);

        // correct[k] counts the trials in which exactly k test rows were classified correctly.
        int testSize = testRows.size();
        int[] correct = new int[testSize + 1];
        for (int trial = 0; trial < 200; trial++) {
            OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 5, new L2(1.0d));

            // 30 passes over the training rows, reshuffled before each pass.
            for (int pass = 0; pass < 30; pass++) {
                Collections.shuffle(trainRows, random);
                for (int row : trainRows) {
                    int label = target.get(row);
                    lr.train(label, data.get(row));
                }
            }

            int hits = 0;
            for (int row : testRows) {
                int predicted = lr.classifyFull(data.get(row)).maxValueIndex();
                int actual = target.get(row);
                if (predicted == actual) {
                    hits++;
                }
            }
            correct[hits]++;
        }

        // No trial may fall below the minimum acceptable number of correct test rows.
        int minAcceptable = (int) Math.floor(0.95d * testSize);
        for (int k = 0; k < minAcceptable; k++) {
            assertEquals(
                    String.format("%d trials had unacceptable accuracy of only %.0f%%: ",
                            correct[k], (100.0d * k) / testSize),
                    0L, correct[k]);
        }

        // A perfect score is treated as suspicious. The message reports the same bucket that
        // is asserted on (previously it printed the count from the neighbouring bucket).
        assertEquals(
                String.format("%d trials had unrealistic accuracy of 100%%", correct[testSize]),
                0L, correct[testSize]);
    }

    /**
     * Trains a plain {@link OnlineLogisticRegression} (L1 prior, lambda 0.001, learning rate
     * 50) on the standard data and runs the shared {@code test} check with the thresholds 0.05
     * and 0.3.
     *
     * <p>The helpers used here are not defined in this class; presumably they are inherited
     * from {@code OnlineBaseTest}.
     *
     * @throws Exception if loading the data or training fails
     */
    @Test
    public void testTrain() throws Exception {
        Vector target = readStandardData();

        OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 8, new L1())
                .lambda(0.001d)
                .learningRate(50.0d);

        train(getInput(), target, lr);
        test(getInput(), target, lr, 0.05d, 0.3d);
    }

    /**
     * Writes a configured {@link OnlineLogisticRegression} with {@link PolymorphicWritable},
     * reads it back, and checks that the configuration survived the round trip.
     *
     * <p>{@code lambda} is checked through its getter. {@code stepOffset}, {@code decayFactor},
     * {@code mu0} and {@code forgettingExponent} are read reflectively from the deserialized
     * instance, so a field that is not restored on read makes the test fail.
     *
     * @throws Exception if serialization, deserialization or reflective access fails
     */
    @Test
    public void testSerializationAndDeSerialization() throws Exception {
        OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 8, new L1())
                .lambda(0.001d)
                .stepOffset(11)
                .alpha(0.01d)
                .learningRate(50.0d)
                .decayExponent(-0.02d);
        lr.close();

        // Serialize; the streams are closed in finally even if the write throws.
        byte[] serialized;
        Closer writeCloser = Closer.create();
        try {
            ByteArrayOutputStream bytes = writeCloser.register(new ByteArrayOutputStream());
            DataOutputStream out = writeCloser.register(new DataOutputStream(bytes));
            PolymorphicWritable.write(out, lr);
            serialized = bytes.toByteArray();
        } finally {
            writeCloser.close();
        }

        // Deserialize; as before, the restored model is registered and closed with the streams.
        OnlineLogisticRegression restored;
        Closer readCloser = Closer.create();
        try {
            ByteArrayInputStream bytes = readCloser.register(new ByteArrayInputStream(serialized));
            DataInputStream in = readCloser.register(new DataInputStream(bytes));
            restored = readCloser.register(PolymorphicWritable.read(in, OnlineLogisticRegression.class));
        } finally {
            readCloser.close();
        }

        Assert.assertEquals(0.001d, restored.getLambda(), 1.0E-7d);

        int stepOffset = (Integer) readPrivateField(restored, "stepOffset");
        Assert.assertEquals(11L, stepOffset);

        double decayFactor = (Double) readPrivateField(restored, "decayFactor");
        Assert.assertEquals(0.01d, decayFactor, 1.0E-7d);

        double mu0 = (Double) readPrivateField(restored, "mu0");
        Assert.assertEquals(50.0d, mu0, 1.0E-7d);

        double forgettingExponent = (Double) readPrivateField(restored, "forgettingExponent");
        Assert.assertEquals(-0.02d, forgettingExponent, 1.0E-7d);
    }

    /** Reads the named field declared on {@link OnlineLogisticRegression} from {@code model} via reflection. */
    private static Object readPrivateField(OnlineLogisticRegression model, String fieldName) throws Exception {
        Field field = OnlineLogisticRegression.class.getDeclaredField(fieldName);
        field.setAccessible(true);
        return field.get(model);
    }
}
