/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier.sgd.bankmarketing;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import org.apache.mahout.classifier.evaluation.Auc;
import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.classifier.sgd.PriorFunction;
import org.apache.mahout.classifier.sgd.bankmarketing.TelephoneCall;
import org.apache.mahout.classifier.sgd.bankmarketing.TelephoneCallParser;

/**
 * Trains and evaluates an L1-regularized {@link OnlineLogisticRegression} on telephone-call records
 * read by {@link TelephoneCallParser}.
 *
 * <p>The experiment is repeated {@code NUM_RUNS} times. Each run does the following:
 * <ul>
 *   <li>shuffles the records;</li>
 *   <li>holds out the first {@code HELD_OUT_PERCENTAGE} of them as a test set and trains on the rest;</li>
 *   <li>trains a fresh model for {@code NUM_PASSES} passes over the training set;</li>
 *   <li>every {@code EVALUATION_INTERVAL} passes (starting at pass 0), prints
 *       {@code pass, currentLearningRate, auc} as one comma-separated line.</li>
 * </ul>
 */
public class BankMarketingClassificationMain {
    /** Number of target categories the model distinguishes. */
    public static final int NUM_CATEGORIES = 2;

    /** Input file used when no path is given on the command line. */
    private static final String DEFAULT_INPUT_FILE = "bank-full.csv";
    /** Fraction of the shuffled records used as the held-out test set. */
    private static final double HELD_OUT_PERCENTAGE = 0.1;
    /** Number of independent shuffle/train/evaluate repetitions. */
    private static final int NUM_RUNS = 20;
    /** Number of passes over the training set per run. */
    private static final int NUM_PASSES = 20;
    /** The model is evaluated on passes where {@code pass % EVALUATION_INTERVAL == 0}. */
    private static final int EVALUATION_INTERVAL = 5;
    /**
     * Feature count passed to the model constructor.
     * NOTE(review): presumably must match the size of TelephoneCall.asVector() — confirm.
     */
    private static final int NUM_FEATURES = 100;
    /** Threshold handed to {@link Auc}'s constructor. */
    private static final double AUC_THRESHOLD = 0.5;

    /**
     * Runs the train/evaluate experiment and prints progress to standard output.
     *
     * @param args optional; {@code args[0]} overrides the input file path
     *     (default {@code bank-full.csv})
     * @throws Exception propagated from parsing, training or evaluation
     */
    public static void main(String[] args) throws Exception {
        String inputFile = (args != null && args.length > 0) ? args[0] : DEFAULT_INPUT_FILE;
        // Assumes TelephoneCallParser implements Iterable<TelephoneCall> — the raw-typed decompiled
        // version did not compile because the for-each loops iterated Object.
        List<TelephoneCall> calls = Lists.newArrayList(new TelephoneCallParser(inputFile));

        for (int run = 0; run < NUM_RUNS; ++run) {
            Collections.shuffle(calls);
            int cutoff = (int) (HELD_OUT_PERCENTAGE * calls.size());
            // Views over the freshly shuffled list; they are recreated each run.
            List<TelephoneCall> test = calls.subList(0, cutoff);
            List<TelephoneCall> train = calls.subList(cutoff, calls.size());

            OnlineLogisticRegression lr = newModel();
            for (int pass = 0; pass < NUM_PASSES; ++pass) {
                for (TelephoneCall observation : train) {
                    lr.train(observation.getTarget(), observation.asVector());
                }
                if (pass % EVALUATION_INTERVAL != 0) {
                    continue;
                }
                // Score the test set before reading the learning rate, as the original did.
                double auc = auc(lr, test);
                // Locale.ROOT keeps '.' as the decimal separator so the comma-separated columns
                // stay unambiguous regardless of the default locale.
                System.out.printf(Locale.ROOT, "%d, %.4f, %.4f%n", pass, lr.currentLearningRate(), auc);
            }
        }
    }

    /** Builds a fresh, untrained model with the experiment's fixed hyper-parameters. */
    private static OnlineLogisticRegression newModel() {
        return new OnlineLogisticRegression(NUM_CATEGORIES, NUM_FEATURES, new L1())
                .learningRate(1.0)
                .alpha(1.0)
                .lambda(1.0E-6)
                .stepOffset(10000)
                .decayExponent(0.2);
    }

    /** Computes the AUC of {@code lr}'s scalar scores against the targets of {@code test}. */
    private static double auc(OnlineLogisticRegression lr, List<TelephoneCall> test) {
        Auc eval = new Auc(AUC_THRESHOLD);
        for (TelephoneCall testCall : test) {
            eval.add(testCall.getTarget(), lr.classifyScalar(testCall.asVector()));
        }
        return eval.auc();
    }
}

