/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier.sgd;

import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Locale;
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.commons.cli2.option.DefaultOption;
import org.apache.commons.cli2.util.HelpFormatter;
import org.apache.mahout.classifier.ConfusionMatrix;
import org.apache.mahout.classifier.evaluation.Auc;
import org.apache.mahout.classifier.sgd.AdaptiveLogisticModelParameters;
import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression;
import org.apache.mahout.classifier.sgd.CrossFoldLearner;
import org.apache.mahout.classifier.sgd.CsvRecordFactory;
import org.apache.mahout.classifier.sgd.TrainLogistic;
import org.apache.mahout.ep.State;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.stats.OnlineSummarizer;

/**
 * Command-line utility that loads a saved {@link AdaptiveLogisticModelParameters}
 * file, runs the best learner found in it over a CSV input file, and prints
 * log-likelihood statistics, and optionally per-row scores, AUC and a
 * confusion matrix, to a {@link PrintWriter}.
 *
 * <p>The parsed command-line options are kept in static fields, so each call to
 * {@link #mainToOutput(String[], PrintWriter)} overwrites the values left by the
 * previous call.
 */
public final class ValidateAdaptiveLogistic {
    /** Value of the required {@code --input} option. */
    private static String inputFile;
    /** Value of the required {@code --model} option. */
    private static String modelFile;
    /** Value of {@code --defaultCategory}; the option declares "unknown" as its default. */
    private static String defaultCategory;
    /**
     * Whether {@code --auc} was given.
     * NOTE(review): this flag only takes part in the "nothing selected" default below;
     * the AUC line itself is printed whenever an AUC collector exists. Confirm intent.
     */
    private static boolean showAuc;
    /** Whether {@code --scores} was given; enables the per-row output. */
    private static boolean showScores;
    /** Whether {@code --confusion} was given; enables the confusion/entropy matrix output. */
    private static boolean showConfusion;

    /** Not instantiable; the class only exposes static entry points. */
    private ValidateAdaptiveLogistic() {
    }

    /**
     * Program entry point; forwards to {@link #mainToOutput(String[], PrintWriter)}
     * with an auto-flushing writer on {@code System.out}.
     *
     * @param args command-line arguments
     * @throws IOException if reading the model or the input data fails
     */
    public static void main(String[] args) throws IOException {
        mainToOutput(args, new PrintWriter(System.out, true));
    }

    /**
     * Parses {@code args}, loads the model and validates it against the input file,
     * writing all results to {@code output}.
     *
     * <p>Nothing is written when argument parsing does not yield a command line.
     * When none of {@code --auc}, {@code --confusion} or {@code --scores} is given,
     * AUC and confusion output are switched on.
     *
     * @param args   command-line arguments
     * @param output destination for every line this tool prints
     * @throws IOException if reading the model or the input data fails
     */
    static void mainToOutput(String[] args, PrintWriter output) throws IOException {
        if (!parseArgs(args)) {
            return;
        }
        if (!(showAuc || showConfusion || showScores)) {
            showAuc = true;
            showConfusion = true;
        }

        AdaptiveLogisticModelParameters lmp = AdaptiveLogisticModelParameters.loadFromFile(new File(modelFile));
        CsvRecordFactory csv = lmp.getCsvRecordFactory();
        AdaptiveLogisticRegression lr = lmp.createAdaptiveLogisticRegression();

        // An AUC collector is only created for models with at most two target categories.
        Auc collector = null;
        if (lmp.getTargetCategories().size() <= 2) {
            collector = new Auc();
        }

        OnlineSummarizer slh = new OnlineSummarizer();
        ConfusionMatrix cm = new ConfusionMatrix(lmp.getTargetCategories(), defaultCategory);

        State best = lr.getBest();
        if (best == null) {
            output.printf("%s\n", "AdaptiveLogisticRegression has not be trained probably.");
            return;
        }
        CrossFoldLearner learner = ((AdaptiveLogisticRegression.Wrapper) best.getPayload()).getLearner();

        // try-with-resources guarantees the reader is closed even if a row fails to parse.
        try (BufferedReader in = TrainLogistic.open(inputFile)) {
            // The first line is handed to the record factory as the header; data starts on line two.
            String line = in.readLine();
            csv.firstLine(line);
            line = in.readLine();

            if (showScores) {
                output.printf(Locale.ENGLISH, "\"%s\", \"%s\", \"%s\", \"%s\"\n", "target", "model-output", "log-likelihood", "average-likelihood");
            }

            while (line != null) {
                Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
                int target = csv.processLine(line, v);
                double likelihood = learner.logLikelihood(target, v);
                double score = learner.classifyFull(v).maxValue();

                slh.add(likelihood);
                cm.addInstance(csv.getTargetString(line), csv.getTargetLabel(target));

                if (showScores) {
                    // Reuse the likelihood computed above instead of evaluating the model a second time.
                    output.printf(Locale.ENGLISH, "%8d, %.12f, %.13f, %.13f\n", target, score, likelihood, slh.getMean());
                }
                if (collector != null) {
                    collector.add(target, score);
                }
                line = in.readLine();
            }
        }

        output.printf(Locale.ENGLISH, "\nLog-likelihood:");
        output.printf(Locale.ENGLISH, "Min=%.2f, Max=%.2f, Mean=%.2f, Median=%.2f\n", slh.getMin(), slh.getMax(), slh.getMean(), slh.getMedian());

        if (collector != null) {
            output.printf(Locale.ENGLISH, "\nAUC = %.2f\n", collector.auc());
        }

        if (showConfusion) {
            output.printf(Locale.ENGLISH, "\n%s\n\n", cm.toString());
            if (collector != null) {
                Matrix m = collector.entropy();
                output.printf(Locale.ENGLISH, "Entropy Matrix: [[%.1f, %.1f], [%.1f, %.1f]]\n", m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1));
            }
        }
    }

    /**
     * Defines the supported options, parses {@code args} and stores the results in
     * the static fields of this class.
     *
     * @param args command-line arguments
     * @return {@code false} when the parser returns no command line (in which case the
     *         static fields are left untouched), {@code true} otherwise
     */
    private static boolean parseArgs(String[] args) {
        DefaultOptionBuilder builder = new DefaultOptionBuilder();

        Option help = builder.withLongName("help").withDescription("print this list").create();
        // NOTE(review): "quiet" is accepted on the command line but never read in this class.
        Option quiet = builder.withLongName("quiet").withDescription("be extra quiet").create();
        Option auc = builder.withLongName("auc").withDescription("print AUC").create();
        Option confusion = builder.withLongName("confusion").withDescription("print confusion matrix").create();
        Option scores = builder.withLongName("scores").withDescription("print scores").create();

        ArgumentBuilder argumentBuilder = new ArgumentBuilder();
        Option inputFileOption = builder.withLongName("input")
                .withRequired(true)
                .withArgument(argumentBuilder.withName("input").withMaximum(1).create())
                .withDescription("where to get validate data")
                .create();
        Option modelFileOption = builder.withLongName("model")
                .withRequired(true)
                .withArgument(argumentBuilder.withName("model").withMaximum(1).create())
                .withDescription("where to get the trained model")
                .create();
        Option defaultCategoryOption = builder.withLongName("defaultCategory")
                .withRequired(false)
                .withArgument(argumentBuilder.withName("defaultCategory").withMaximum(1).withDefault("unknown").create())
                .withDescription("the default category value to use")
                .create();

        Group normalArgs = new GroupBuilder()
                .withOption(help)
                .withOption(quiet)
                .withOption(auc)
                .withOption(scores)
                .withOption(confusion)
                .withOption(inputFileOption)
                .withOption(modelFileOption)
                .withOption(defaultCategoryOption)
                .create();

        Parser parser = new Parser();
        parser.setHelpOption(help);
        parser.setHelpTrigger("--help");
        parser.setGroup(normalArgs);
        parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));

        CommandLine cmdLine = parser.parseAndHelp(args);
        if (cmdLine == null) {
            return false;
        }

        inputFile = getStringArgument(cmdLine, inputFileOption);
        modelFile = getStringArgument(cmdLine, modelFileOption);
        defaultCategory = getStringArgument(cmdLine, defaultCategoryOption);
        showAuc = getBooleanArgument(cmdLine, auc);
        showScores = getBooleanArgument(cmdLine, scores);
        showConfusion = getBooleanArgument(cmdLine, confusion);
        return true;
    }

    /**
     * Reports whether a flag-style option is present on the parsed command line.
     *
     * @param cmdLine parsed command line
     * @param option  option to look up
     * @return {@code true} if {@code cmdLine} contains {@code option}
     */
    private static boolean getBooleanArgument(CommandLine cmdLine, Option option) {
        return cmdLine.hasOption(option);
    }

    /**
     * Reads the value of an option from the parsed command line as a string.
     *
     * @param cmdLine parsed command line
     * @param option  option whose value is wanted
     * @return the option's value cast to {@code String}
     */
    private static String getStringArgument(CommandLine cmdLine, Option option) {
        return (String) cmdLine.getValue(option);
    }
}

