package org.apache.mahout.classifier.sgd;

import com.google.common.base.Charsets;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.Writer;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.commons.cli2.option.DefaultOption;
import org.apache.commons.cli2.util.HelpFormatter;
import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression;
import org.apache.mahout.ep.State;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;

/**
 * Command-line tool that scores CSV records with a previously trained
 * {@link AdaptiveLogisticRegression} model.
 *
 * <p>The model (and its CSV record factory) is loaded from {@code --model}. Every data line of
 * {@code --input} is scored. For each record, lines of the form {@code id,label,score} are written
 * to {@code --output}: one per target label, or only the highest-scoring label when
 * {@code --maxscoreonly} is given.
 */
public final class RunAdaptiveLogistic {
    // Command-line values; every field is (re)assigned by parseArgs on each successful parse.
    private static String inputFile;
    private static String modelFile;
    private static String outputFile;
    private static String idColumn;
    private static boolean maxScoreOnly;

    private RunAdaptiveLogistic() {
    }

    /**
     * Entry point; progress messages go to standard output as UTF-8.
     *
     * @param args command-line arguments (see {@link #parseArgs})
     * @throws Exception if loading the model or reading/writing the data files fails
     */
    public static void main(String[] args) throws Exception {
        mainToOutput(args, new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true));
    }

    /**
     * Parses {@code args} and, if parsing succeeds, scores the input file.
     *
     * @param args   command-line arguments
     * @param output destination for progress and status messages
     * @throws Exception if loading the model or reading/writing the data files fails
     */
    static void mainToOutput(String[] args, PrintWriter output) throws Exception {
        if (!parseArgs(args)) {
            return;
        }
        AdaptiveLogisticModelParameters params = AdaptiveLogisticModelParameters.loadFromFile(new File(modelFile));
        CsvRecordFactory csv = params.getCsvRecordFactory();
        csv.setIdName(idColumn);

        State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best =
            params.createAdaptiveLogisticRegression().getBest();
        if (best == null) {
            output.println("AdaptiveLogisticRegression has not be trained probably.");
            return;
        }
        CrossFoldLearner learner = best.getPayload().getLearner();

        int records;
        // Nested try/finally so both streams are closed even if opening the writer or scoring fails.
        BufferedReader in = TrainAdaptiveLogistic.open(inputFile);
        try {
            BufferedWriter out = new BufferedWriter(
                new OutputStreamWriter(new FileOutputStream(outputFile), Charsets.UTF_8));
            try {
                records = scoreRecords(in, out, csv, learner, params.getNumFeatures(), output);
            } finally {
                out.close(); // BufferedWriter.close() flushes buffered output first
            }
        } finally {
            in.close();
        }
        output.println(records + " records processed totally.");
    }

    /**
     * Writes the output header, consumes the input header line, then scores every remaining line.
     *
     * @param in          input reader positioned at the CSV header line
     * @param out         writer for the {@code id,target,score} lines
     * @param csv         record factory used to parse lines and map label indices to names
     * @param learner     trained learner used for classification
     * @param numFeatures cardinality of each feature vector
     * @param output      destination for progress messages (every 100 records)
     * @return the number of records scored
     * @throws IOException if reading the input or writing the output fails
     */
    private static int scoreRecords(BufferedReader in, BufferedWriter out, CsvRecordFactory csv,
                                    CrossFoldLearner learner, int numFeatures, PrintWriter output)
        throws IOException {
        out.write(idColumn + ",target,score");
        out.newLine();
        // The first input line is the CSV header; the factory uses it to learn the column layout.
        csv.firstLine(in.readLine());

        Map<String, Double> scores = new HashMap<String, Double>();
        int records = 0;
        for (String line = in.readLine(); line != null; line = in.readLine()) {
            Vector features = new SequentialAccessSparseVector(numFeatures);
            csv.processLine(line, features, false);
            Vector result = learner.classifyFull(features);

            scores.clear();
            if (maxScoreOnly) {
                scores.put(csv.getTargetLabel(result.maxValueIndex()), result.maxValue());
            } else {
                for (int k = 0; k < result.size(); k++) {
                    scores.put(csv.getTargetLabel(k), result.get(k));
                }
            }

            String id = csv.getIdString(line);
            for (Map.Entry<String, Double> entry : scores.entrySet()) {
                out.write(id + ',' + entry.getKey() + ',' + entry.getValue());
                out.newLine();
            }

            records++;
            if (records % 100 == 0) {
                output.println(records + " records processed");
            }
        }
        return records;
    }

    /**
     * Parses the command line into the static option fields.
     *
     * @param args command-line arguments
     * @return {@code false} if parsing failed or help was requested (help is printed by the
     *     parser), {@code true} otherwise
     */
    private static boolean parseArgs(String[] args) {
        DefaultOptionBuilder builder = new DefaultOptionBuilder();

        Option help = builder.withLongName("help").withDescription("print this list").create();
        // NOTE(review): "quiet" is accepted but never read by this class.
        Option quiet = builder.withLongName("quiet").withDescription("be extra quiet").create();

        ArgumentBuilder argumentBuilder = new ArgumentBuilder();
        Option inputOption = builder.withLongName("input")
            .withRequired(true)
            .withArgument(argumentBuilder.withName("input").withMaximum(1).create())
            .withDescription("where to get training data")
            .create();
        Option modelOption = builder.withLongName("model")
            .withRequired(true)
            .withArgument(argumentBuilder.withName("model").withMaximum(1).create())
            .withDescription("where to get the trained model")
            .create();
        Option outputOption = builder.withLongName("output")
            .withRequired(true)
            .withDescription("the file path to output scores")
            .withArgument(argumentBuilder.withName("output").withMaximum(1).create())
            .create();
        Option idColumnOption = builder.withLongName("idcolumn")
            .withRequired(true)
            .withDescription("the name of the id column for each record")
            .withArgument(argumentBuilder.withName("idcolumn").withMaximum(1).create())
            .create();
        Option maxScoreOnlyOption = builder.withLongName("maxscoreonly")
            .withDescription("only output the target label with max scores")
            .create();

        Group options = new GroupBuilder()
            .withOption(help)
            .withOption(quiet)
            .withOption(inputOption)
            .withOption(modelOption)
            .withOption(outputOption)
            .withOption(idColumnOption)
            .withOption(maxScoreOnlyOption)
            .create();

        Parser parser = new Parser();
        parser.setHelpOption(help);
        parser.setHelpTrigger("--help");
        parser.setGroup(options);
        parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
        CommandLine cmdLine = parser.parseAndHelp(args);
        if (cmdLine == null) {
            return false;
        }

        inputFile = getStringArgument(cmdLine, inputOption);
        modelFile = getStringArgument(cmdLine, modelOption);
        outputFile = getStringArgument(cmdLine, outputOption);
        idColumn = getStringArgument(cmdLine, idColumnOption);
        maxScoreOnly = getBooleanArgument(cmdLine, maxScoreOnlyOption);
        return true;
    }

    /** Returns whether the flag {@code option} was present on the command line. */
    private static boolean getBooleanArgument(CommandLine commandLine, Option option) {
        return commandLine.hasOption(option);
    }

    /** Returns the single string value given for {@code option}. */
    private static String getStringArgument(CommandLine commandLine, Option option) {
        return (String) commandLine.getValue(option);
    }
}
