package org.apache.mahout.classifier.sgd.bankmarketing;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.mahout.classifier.NewsgroupHelper;
import org.apache.mahout.classifier.evaluation.Auc;
import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;

/**
 * Command-line experiment that trains and evaluates a binary {@link OnlineLogisticRegression}
 * classifier on the bank-marketing telephone-call data set.
 *
 * <p>The program performs {@code NUM_RUNS} independent runs. Each run:
 * <ol>
 *   <li>shuffles all calls;</li>
 *   <li>holds out the first {@code HELD_OUT_FRACTION} of them as a test set;</li>
 *   <li>trains a fresh model for {@code NUM_PASSES} passes over the remaining calls.</li>
 * </ol>
 * Every {@code EVAL_INTERVAL} passes it prints one line of the form {@code pass, learningRate, auc}.
 */
public class BankMarketingClassificationMain {
    /** Number of target classes; the model is binary. */
    public static final int NUM_CATEGORIES = 2;

    /** Data file read when no path is supplied on the command line. */
    private static final String DEFAULT_DATA_FILE = "bank-full.csv";

    /**
     * Size of the model's feature space.
     * NOTE(review): must match the cardinality of vectors produced by {@code TelephoneCall.asVector()};
     * confirm against that class.
     */
    private static final int NUM_FEATURES = 100;

    /** Fraction of the shuffled calls held out for evaluation in each run. */
    private static final double HELD_OUT_FRACTION = 0.1;

    /** Number of independent shuffle/train/evaluate runs. */
    private static final int NUM_RUNS = 20;

    /** Number of passes over the training split per run. */
    private static final int NUM_PASSES = 20;

    /** Held-out AUC is reported on every pass index divisible by this value. */
    private static final int EVAL_INTERVAL = 5;

    // Learner hyper-parameters passed to OnlineLogisticRegression's fluent setters.
    private static final double LEARNING_RATE = 1.0;
    private static final double ALPHA = 1.0;
    private static final double LAMBDA = 1.0e-6;
    // NOTE(review): earlier code passed NewsgroupHelper.FEATURES (a feature count of a different
    // classifier) here. This literal assumes that constant's value is 10000; confirm.
    private static final int STEP_OFFSET = 10000;
    private static final double DECAY_EXPONENT = 0.2;

    /**
     * Runs the repeated train/evaluate experiment and prints progress to standard output.
     *
     * @param args optional; if present, {@code args[0]} is used as the data file path instead of
     *     {@code DEFAULT_DATA_FILE}
     * @throws Exception whatever the data parser or the model raises; nothing is caught here
     */
    public static void main(String[] args) throws Exception {
        String dataFile = (args != null && args.length > 0) ? args[0] : DEFAULT_DATA_FILE;
        List<TelephoneCall> calls = Lists.newArrayList(new TelephoneCallParser(dataFile));

        for (int run = 0; run < NUM_RUNS; run++) {
            Collections.shuffle(calls);
            int cutoff = (int) (HELD_OUT_FRACTION * calls.size());
            // Both splits are views over 'calls'. They are re-created after every shuffle, so
            // the structural changes made by the next shuffle never invalidate a live view.
            List<TelephoneCall> test = calls.subList(0, cutoff);
            List<TelephoneCall> train = calls.subList(cutoff, calls.size());

            OnlineLogisticRegression learner = newLearner();
            for (int pass = 0; pass < NUM_PASSES; pass++) {
                for (TelephoneCall call : train) {
                    learner.train(call.getTarget(), call.asVector());
                }
                if (pass % EVAL_INTERVAL == 0) {
                    // Score the held-out set before querying the learning rate (same order as before).
                    double auc = heldOutAuc(learner, test);
                    System.out.printf("%d, %.4f, %.4f\n", pass, learner.currentLearningRate(), auc);
                }
            }
        }
    }

    /** Builds a fresh, untrained model with an {@link L1} prior and the hyper-parameters above. */
    private static OnlineLogisticRegression newLearner() {
        return new OnlineLogisticRegression(NUM_CATEGORIES, NUM_FEATURES, new L1())
                .learningRate(LEARNING_RATE)
                .alpha(ALPHA)
                .lambda(LAMBDA)
                .stepOffset(STEP_OFFSET)
                .decayExponent(DECAY_EXPONENT);
    }

    /**
     * Scores every held-out call with {@code model} and returns the AUC over those scores.
     *
     * @param model the model to evaluate; it is only queried, not trained
     * @param test the held-out calls
     * @return the value of {@link Auc#auc()} after adding every (target, score) pair
     */
    private static double heldOutAuc(OnlineLogisticRegression model, List<TelephoneCall> test) {
        // NOTE(review): 0.5 is presumably Auc's decision threshold; confirm against Auc's docs.
        Auc eval = new Auc(0.5);
        for (TelephoneCall call : test) {
            eval.add(call.getTarget(), model.classifyScalar(call.asVector()));
        }
        return eval.auc();
    }
}
