/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier.sgd;

import com.google.common.collect.HashMultiset;
import com.google.common.collect.Lists;
import com.google.common.collect.Ordering;
import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.PathFilter;
import org.apache.hadoop.io.Text;
import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression;
import org.apache.mahout.classifier.sgd.CrossFoldLearner;
import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.ModelSerializer;
import org.apache.mahout.classifier.sgd.SGDHelper;
import org.apache.mahout.classifier.sgd.SGDInfo;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterator;
import org.apache.mahout.ep.State;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.vectorizer.encoders.Dictionary;

/**
 * Command-line job that trains an {@link AdaptiveLogisticRegression} model from
 * sequence files of (label, vector) pairs and writes the best model found to
 * {@code <output>/asf.model}.
 *
 * <p>Only sequence files under the input path whose file name contains
 * {@code "training"} are read. The data is scanned twice: once to intern every
 * label in a {@link Dictionary}, and once to feed the examples to the learner.
 */
public final class TrainASFEmail
extends AbstractJob {
    /** Accepts only paths whose final name component contains {@code "training"}. */
    private static final PathFilter TRAINING_FILTER = new PathFilter(){

        @Override
        public boolean accept(Path path) {
            return path.getName().contains("training");
        }
    };

    private TrainASFEmail() {
    }

    /**
     * Parses the command line, trains the model and serializes it.
     *
     * @param args command-line arguments; {@code categories} is registered as
     *             required, while {@code cardinality}, {@code threads} and
     *             {@code poolSize} default to 100000, 20 and 5
     * @return {@code 0} on success, {@code -1} if the arguments could not be parsed
     * @throws Exception if the output directory cannot be created, or if reading
     *                   the training data or writing the model fails
     */
    @Override
    public int run(String[] args) throws Exception {
        this.addInputOption();
        this.addOutputOption();
        this.addOption("categories", "nc", "The number of categories to train on", true);
        this.addOption("cardinality", "c", "The size of the vectors to use", "100000");
        this.addOption("threads", "t", "The number of threads to use in the learner", "20");
        this.addOption("poolSize", "p", "The number of CrossFoldLearners to use in the AdaptiveLogisticRegression. Higher values require more memory.", "5");
        if (this.parseArguments(args) == null) {
            return -1;
        }
        File base = new File(this.getInputPath().toString());
        // NOTE(review): nothing in this class ever adds to overallCounts, so the
        // "Word counts" listing printed at the end is always empty.
        HashMultiset<String> overallCounts = HashMultiset.create();
        File output = new File(this.getOutputPath().toString());
        // Fail before the (long) training run rather than at model-write time.
        // mkdirs() also returns false when the directory already exists, hence
        // the additional isDirectory() check.
        if (!output.mkdirs() && !output.isDirectory()) {
            throw new java.io.IOException("Could not create output directory " + output);
        }
        int numCats = Integer.parseInt(this.getOption("categories"));
        int cardinality = Integer.parseInt(this.getOption("cardinality", "100000"));
        int threadCount = Integer.parseInt(this.getOption("threads", "20"));
        int poolSize = Integer.parseInt(this.getOption("poolSize", "5"));
        Dictionary asfDictionary = new Dictionary();
        AdaptiveLogisticRegression learningAlgorithm = new AdaptiveLogisticRegression(numCats, cardinality, new L1(), threadCount, poolSize);
        learningAlgorithm.setInterval(800);
        learningAlgorithm.setAveragingWindow(500);
        Configuration conf = new Configuration();

        // Pass 1: intern every label so each one has a category index before
        // training starts, and count the examples.
        long numItems = 0L;
        SequenceFileDirIterator<Text, VectorWritable> labelPass = TrainASFEmail.openTrainingData(base, conf);
        try {
            while (labelPass.hasNext()) {
                Pair<Text, VectorWritable> next = labelPass.next();
                asfDictionary.intern(next.getFirst().toString());
                ++numItems;
            }
        } finally {
            labelPass.close();
        }
        System.out.println(numItems + " training files");

        // Pass 2: train on each (label index, vector) pair, reporting progress
        // through SGDHelper after every example.
        SGDInfo info = new SGDInfo();
        int k = 0;
        SequenceFileDirIterator<Text, VectorWritable> trainPass = TrainASFEmail.openTrainingData(base, conf);
        try {
            while (trainPass.hasNext()) {
                Pair<Text, VectorWritable> next = trainPass.next();
                String ng = next.getFirst().toString();
                int actual = asfDictionary.intern(ng);
                learningAlgorithm.train(actual, next.getSecond().get());
                State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest();
                SGDHelper.analyzeState(info, 0, ++k, best);
            }
        } finally {
            trainPass.close();
        }
        learningAlgorithm.close();
        System.out.println("exiting main, writing model to " + output);
        ModelSerializer.writeBinary(output + "/asf.model", learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));

        // Print the collected counts in descending order (see note on overallCounts).
        ArrayList<Integer> counts = Lists.newArrayList();
        System.out.println("Word counts");
        for (String word : overallCounts.elementSet()) {
            counts.add(overallCounts.count(word));
        }
        Collections.sort(counts, Ordering.natural().reverse());
        k = 0;
        for (Integer count : counts) {
            System.out.println(k + "\t" + count);
            if (++k <= 1000) continue;
            break;
        }
        return 0;
    }

    /**
     * Opens an iterator over the training sequence files under {@code base}.
     * The caller is responsible for closing the returned iterator.
     */
    private static SequenceFileDirIterator<Text, VectorWritable> openTrainingData(File base, Configuration conf) throws java.io.IOException {
        return new SequenceFileDirIterator<Text, VectorWritable>(new Path(base.toString()), PathType.LIST, TRAINING_FILTER, null, true, conf);
    }

    /**
     * Command-line entry point. Exits with the status returned by
     * {@link #run(String[])} when that status is non-zero.
     *
     * @param args command-line arguments, passed through to {@link #run(String[])}
     * @throws Exception if training fails
     */
    public static void main(String[] args) throws Exception {
        TrainASFEmail trainer = new TrainASFEmail();
        int exitCode = trainer.run(args);
        if (exitCode != 0) {
            System.exit(exitCode);
        }
    }
}

