/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier.sgd;

import com.google.common.collect.HashMultiset;
import com.google.common.collect.Multiset;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.Writer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.commons.cli2.option.DefaultOption;
import org.apache.commons.cli2.util.HelpFormatter;
import org.apache.commons.io.Charsets;
import org.apache.mahout.classifier.ClassifierResult;
import org.apache.mahout.classifier.NewsgroupHelper;
import org.apache.mahout.classifier.ResultAnalyzer;
import org.apache.mahout.classifier.sgd.ModelSerializer;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.math.Vector;
import org.apache.mahout.vectorizer.encoders.Dictionary;

/**
 * Command-line tool that evaluates a serialized {@link OnlineLogisticRegression} newsgroup
 * classifier against a directory of test documents.
 *
 * <p>The input directory is expected to contain one sub-directory per newsgroup; every file inside
 * a sub-directory is labelled with that sub-directory's name. Each file is encoded with
 * {@link NewsgroupHelper}, classified, and recorded in a {@link ResultAnalyzer}, whose summary is
 * printed at the end.
 */
public final class TestNewsGroups {
    /** Root directory of the test documents; set by {@link #parseArgs(String[])}. */
    private String inputFile;
    /** Path of the binary-serialized model; set by {@link #parseArgs(String[])}. */
    private String modelFile;

    private TestNewsGroups() {
    }

    /**
     * Parses the command line and, if it is valid, runs the evaluation with output written to
     * standard output as UTF-8 (auto-flushing).
     *
     * @param args command-line arguments ({@code --input}, {@code --model}, {@code --help})
     * @throws IOException if the model or any test document cannot be read
     */
    public static void main(String[] args) throws IOException {
        TestNewsGroups runner = new TestNewsGroups();
        if (runner.parseArgs(args)) {
            runner.run(new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true));
        }
    }

    /**
     * Loads the model, classifies every test document and prints the file count followed by the
     * {@link ResultAnalyzer} summary to {@code output}.
     *
     * @param output destination for all report text
     * @throws IOException if the model cannot be read, a directory cannot be listed, or a document
     *     cannot be encoded
     */
    public void run(PrintWriter output) throws IOException {
        File base = new File(this.inputFile);
        OnlineLogisticRegression classifier = readModel(this.modelFile);

        // Intern every newsgroup name up front so the label dictionary is complete before the
        // ResultAnalyzer is built from it.
        Dictionary newsGroups = new Dictionary();
        Multiset<String> overallCounts = HashMultiset.create();
        List<File> files = new ArrayList<File>();
        for (File newsgroup : listFilesOrThrow(base)) {
            if (!newsgroup.isDirectory()) {
                continue;
            }
            newsGroups.intern(newsgroup.getName());
            files.addAll(Arrays.asList(listFilesOrThrow(newsgroup)));
        }
        // Report through the caller-supplied writer rather than System.out so that all output
        // honours the destination (and encoding) the caller chose.
        output.println(files.size() + " test files");

        ResultAnalyzer ra = new ResultAnalyzer(newsGroups.values(), "DEFAULT");
        for (File file : files) {
            String ng = file.getParentFile().getName();
            int actual = newsGroups.intern(ng);
            // NOTE(review): a fresh helper is built per file, as in the original; hoisting it out of
            // the loop may be cheaper but depends on whether NewsgroupHelper keeps per-document
            // state -- confirm before changing.
            NewsgroupHelper helper = new NewsgroupHelper();
            // leakType 0 is passed through unchanged from the original tool.
            Vector input = helper.encodeFeatureVector(file, actual, 0, overallCounts);
            Vector result = classifier.classifyFull(input);
            int cat = result.maxValueIndex();
            double score = result.maxValue();
            double ll = classifier.logLikelihood(actual, input);
            ClassifierResult cr = new ClassifierResult((String) newsGroups.values().get(cat), score, ll);
            ra.addInstance((String) newsGroups.values().get(actual), cr);
        }
        output.println(ra);
    }

    /**
     * Reads a binary-serialized {@link OnlineLogisticRegression} from {@code path}, always closing
     * the underlying stream.
     */
    private static OnlineLogisticRegression readModel(String path) throws IOException {
        InputStream in = new FileInputStream(path);
        try {
            return (OnlineLogisticRegression) ModelSerializer.readBinary(in, OnlineLogisticRegression.class);
        } finally {
            in.close();
        }
    }

    /**
     * Lists the children of {@code dir}, converting {@link File#listFiles()}'s {@code null} result
     * (not a directory, or an I/O error) into a descriptive {@link IOException}.
     */
    private static File[] listFilesOrThrow(File dir) throws IOException {
        File[] children = dir.listFiles();
        if (children == null) {
            throw new IOException("Cannot list files in " + dir
                + " (not a directory, or an I/O error occurred)");
        }
        return children;
    }

    /**
     * Parses {@code --input}, {@code --model} and {@code --help} into this runner's fields.
     *
     * @param args command-line arguments
     * @return {@code true} if parsing succeeded and the fields were set; {@code false} if
     *     {@code parseAndHelp} returned {@code null} (cli2 does this after showing help or
     *     reporting invalid arguments)
     */
    boolean parseArgs(String[] args) {
        DefaultOptionBuilder builder = new DefaultOptionBuilder();
        DefaultOption help = builder.withLongName("help").withDescription("print this list").create();

        ArgumentBuilder argumentBuilder = new ArgumentBuilder();
        DefaultOption inputFileOption = builder.withLongName("input")
            .withRequired(true)
            .withArgument(argumentBuilder.withName("input").withMaximum(1).create())
            .withDescription("where to get test data")
            .create();
        DefaultOption modelFileOption = builder.withLongName("model")
            .withRequired(true)
            .withArgument(argumentBuilder.withName("model").withMaximum(1).create())
            .withDescription("where to get a model")
            .create();

        Group normalArgs = new GroupBuilder()
            .withOption(help)
            .withOption(inputFileOption)
            .withOption(modelFileOption)
            .create();

        Parser parser = new Parser();
        parser.setHelpOption(help);
        parser.setHelpTrigger("--help");
        parser.setGroup(normalArgs);
        parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
        CommandLine cmdLine = parser.parseAndHelp(args);
        if (cmdLine == null) {
            return false;
        }

        this.inputFile = (String) cmdLine.getValue(inputFileOption);
        this.modelFile = (String) cmdLine.getValue(modelFileOption);
        return true;
    }
}

