/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.clustering.lda.cvb;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Random;
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.OptionException;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.commons.cli2.option.DefaultOption;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.clustering.lda.cvb.ModelTrainer;
import org.apache.mahout.clustering.lda.cvb.TopicModel;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.DistributedRowMatrixWriter;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixSlice;
import org.apache.mahout.math.NamedVector;
import org.apache.mahout.math.SparseRowMatrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorIterable;
import org.apache.mahout.math.VectorWritable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Command-line job that trains a topic model over a corpus held entirely in memory.
 *
 * <p>The corpus is loaded from sequence files into a single {@link Matrix} (one row per document),
 * a {@link ModelTrainer} is run over every row for a number of passes, and the resulting model and
 * the document/topic matrix are written out. Progress is tracked through the perplexity value
 * reported by the trainer.
 */
public class InMemoryCollapsedVariationalBayes0 extends AbstractJob {
    private static final Logger log = LoggerFactory.getLogger(InMemoryCollapsedVariationalBayes0.class);

    /** Number of topics to learn. */
    private int numTopics;
    /** Dictionary size if a dictionary was supplied, otherwise the corpus column count. */
    private int numTerms;
    /** Number of rows in {@link #corpusWeights}. */
    private int numDocuments;
    /** Smoothing parameter for p(topic | document), forwarded to {@link TopicModel}. */
    private double alpha;
    /** Smoothing parameter for p(term | topic), forwarded to {@link TopicModel}. */
    private double eta;
    /** When set, the current read model is logged after every training pass. */
    private boolean verbose = false;
    /** Term strings indexed by term id; may be {@code null} when no dictionary was given. */
    private String[] terms;
    /** The corpus: one row per document. */
    private Matrix corpusWeights;
    /** Sum of the L1 norms of all non-empty corpus rows. */
    private double totalCorpusWeight;
    /** Initial |model|/|corpus| ratio; {@code 0.0} selects two separate models (see initializeModel). */
    private double initialModelCorpusFraction;
    /** Per-document topic weights; initialised uniformly and written out by {@link #main2}. */
    private Matrix docTopicCounts;
    private int numTrainingThreads;
    private int numUpdatingThreads;
    private ModelTrainer modelTrainer;

    /** Used only by {@link #main(String[])} to obtain a {@link Tool} instance for ToolRunner. */
    private InMemoryCollapsedVariationalBayes0() {
    }

    /**
     * Enables or disables logging of the model after each training pass.
     *
     * @param verbose {@code true} to log the read model after every pass
     */
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    /**
     * Creates a trainer over an already-loaded corpus and initialises the model.
     *
     * @param corpus              document rows
     * @param terms               term dictionary indexed by term id, or {@code null} to size the
     *                            vocabulary from {@code corpus.numCols()}
     * @param numTopics           number of topics to learn
     * @param alpha               smoothing parameter for p(topic | document)
     * @param eta                 smoothing parameter for p(term | topic)
     * @param numTrainingThreads  thread count handed to the {@link ModelTrainer}
     * @param numUpdatingThreads  thread count handed to the {@link TopicModel}
     * @param modelCorpusFraction initial |model|/|corpus| ratio; {@code 0.0} for none
     */
    public InMemoryCollapsedVariationalBayes0(Matrix corpus, String[] terms, int numTopics, double alpha,
                                              double eta, int numTrainingThreads, int numUpdatingThreads,
                                              double modelCorpusFraction) {
        this.numTopics = numTopics;
        this.alpha = alpha;
        this.eta = eta;
        this.corpusWeights = corpus;
        this.numDocuments = corpus.numRows();
        this.terms = terms;
        this.initialModelCorpusFraction = modelCorpusFraction;
        this.numTerms = terms != null ? terms.length : corpus.numCols();
        this.numTrainingThreads = numTrainingThreads;
        this.numUpdatingThreads = numUpdatingThreads;
        postInitCorpus();
        initializeModel();
    }

    /** Computes {@link #totalCorpusWeight} and logs summary statistics for the corpus. */
    private void postInitCorpus() {
        totalCorpusWeight = 0.0;
        int numNonZero = 0;
        for (int i = 0; i < numDocuments; ++i) {
            Vector v = corpusWeights.viewRow(i);
            if (v == null) {
                continue;
            }
            double norm = v.norm(1.0);
            if (norm == 0.0) {
                // Empty documents contribute neither weight nor non-zero entries.
                continue;
            }
            numNonZero += v.getNumNondefaultElements();
            totalCorpusWeight += norm;
        }
        String s = "Initializing corpus with %d docs, %d terms, %d nonzero entries, total termWeight %f";
        log.info(String.format(s, numDocuments, numTerms, numNonZero, totalCorpusWeight));
    }

    /**
     * Builds the topic model(s), the uniform document/topic matrix and the {@link ModelTrainer}.
     *
     * <p>When {@link #initialModelCorpusFraction} is {@code 0.0} a second, separately constructed
     * model (built with a {@code null} random source) is used as the trainer's second model;
     * otherwise the same model instance is passed for both roles.
     */
    private void initializeModel() {
        double modelWeight = initialModelCorpusFraction == 0.0
            ? 1.0
            : initialModelCorpusFraction * totalCorpusWeight;
        TopicModel topicModel = new TopicModel(numTopics, numTerms, eta, alpha, RandomUtils.getRandom(),
            terms, numUpdatingThreads, modelWeight);
        topicModel.setConf(getConf());

        TopicModel updatedModel = initialModelCorpusFraction == 0.0
            ? new TopicModel(numTopics, numTerms, eta, alpha, null, terms, numUpdatingThreads, 1.0)
            : topicModel;
        updatedModel.setConf(getConf());

        docTopicCounts = new DenseMatrix(numDocuments, numTopics);
        docTopicCounts.assign(1.0 / (double) numTopics);
        modelTrainer = new ModelTrainer(topicModel, updatedModel, numTrainingThreads, numTopics, numTerms);
    }

    /** Runs one training pass over every document in the corpus. */
    public void trainDocuments() {
        trainDocuments(0.0);
    }

    /**
     * Runs one training pass, skipping the held-out documents.
     *
     * <p>A document is held out when {@code testFraction} is non-zero and its row index is an exact
     * multiple of {@code 1 / testFraction}.
     *
     * @param testFraction fraction of documents to hold out; {@code 0.0} trains on all of them
     */
    public void trainDocuments(double testFraction) {
        long start = System.nanoTime();
        modelTrainer.start();
        for (int docId = 0; docId < corpusWeights.numRows(); ++docId) {
            if (testFraction != 0.0 && docId % (1.0 / testFraction) == 0.0) {
                continue;
            }
            // Every document starts each pass from a uniform topic distribution.
            Vector docTopics = new DenseVector(numTopics).assign(1.0 / (double) numTopics);
            // NOTE(review): the meaning of the trailing (true, 10) arguments is defined by
            // ModelTrainer.trainSync, which is not visible here - confirm before changing.
            modelTrainer.trainSync(corpusWeights.viewRow(docId), docTopics, true, 10);
        }
        modelTrainer.stop();
        logTime("train documents", System.nanoTime() - start);
    }

    /**
     * Trains until convergence without holding out any documents.
     *
     * @see #iterateUntilConvergence(double, int, int, double)
     */
    public double iterateUntilConvergence(double minFractionalErrorChange, int maxIterations, int minIter) {
        return iterateUntilConvergence(minFractionalErrorChange, maxIterations, minIter, 0.0);
    }

    /**
     * Trains for {@code minIter} burn-in passes, then keeps training until either
     * {@code maxIterations} total passes have run or the fractional change in perplexity between
     * consecutive passes is no longer greater than {@code minFractionalErrorChange}.
     *
     * @param minFractionalErrorChange convergence threshold on |new - old| / old perplexity
     * @param maxIterations            upper bound on the total number of passes for the convergence loop
     * @param minIter                  number of unconditional burn-in passes
     * @param testFraction             fraction of documents held out from training in every pass
     * @return the most recently computed perplexity ({@code 0.0} only if no pass was run at all)
     */
    public double iterateUntilConvergence(double minFractionalErrorChange, int maxIterations, int minIter,
                                          double testFraction) {
        int iter;
        double oldPerplexity = 0.0;
        for (iter = 0; iter < minIter; ++iter) {
            trainDocuments(testFraction);
            if (verbose) {
                log.info("model after: {}: {}", iter, modelTrainer.getReadModel());
            }
            log.info("iteration {} complete", iter);
            oldPerplexity = modelTrainer.calculatePerplexity(corpusWeights, docTopicCounts, testFraction);
            log.info("{} = perplexity", oldPerplexity);
        }

        // Seed with the last burn-in value so that, if the loop below never runs
        // (minIter >= maxIterations), the caller gets a real perplexity instead of 0.0.
        double newPerplexity = oldPerplexity;
        double fractionalChange = Double.MAX_VALUE;
        while (iter < maxIterations && fractionalChange > minFractionalErrorChange) {
            // Keep honouring the held-out split; training on the test documents here would make
            // the perplexity computed over them below meaningless.
            trainDocuments(testFraction);
            if (verbose) {
                log.info("model after: {}: {}", iter, modelTrainer.getReadModel());
            }
            newPerplexity = modelTrainer.calculatePerplexity(corpusWeights, docTopicCounts, testFraction);
            log.info("{} = perplexity", newPerplexity);
            ++iter;
            fractionalChange = Math.abs(newPerplexity - oldPerplexity) / oldPerplexity;
            log.info("{} = fractionalChange", fractionalChange);
            oldPerplexity = newPerplexity;
        }

        if (iter < maxIterations) {
            log.info(String.format("Converged! fractional error change: %f, error %f",
                fractionalChange, newPerplexity));
        } else {
            log.info(String.format("Reached max iteration count (%d), fractional error change: %f, error: %f",
                maxIterations, fractionalChange, newPerplexity));
        }
        return newPerplexity;
    }

    /**
     * Persists the trained model through the {@link ModelTrainer}.
     *
     * @param outputPath destination path
     * @throws IOException if the trainer fails to write the model
     */
    public void writeModel(Path outputPath) throws IOException {
        modelTrainer.persist(outputPath);
    }

    /** Logs an elapsed time, given in nanoseconds, as fractional milliseconds. */
    private static void logTime(String label, long nanos) {
        log.info("{} time: {}ms", label, (double) nanos / 1000000.0);
    }

    /**
     * Parses the command line, loads dictionary and corpus, trains, and writes the outputs.
     *
     * @param args command-line arguments
     * @param conf Hadoop configuration used for all file-system access
     * @return {@code 0} on success; {@code -1} if help was requested or the arguments were invalid
     * @throws Exception on any I/O or training failure
     */
    public static int main2(String[] args, Configuration conf) throws Exception {
        DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
        ArgumentBuilder abuilder = new ArgumentBuilder();
        GroupBuilder gbuilder = new GroupBuilder();

        Option helpOpt = DefaultOptionCreator.helpOption();

        Option inputDirOpt = obuilder.withLongName("input").withRequired(true)
            .withArgument(abuilder.withName("input").withMinimum(1).withMaximum(1).create())
            .withDescription("The Directory on MapR-FS containing the collapsed, properly formatted files having one doc per line")
            .withShortName("i").create();

        Option dictOpt = obuilder.withLongName("dictionary").withRequired(false)
            .withArgument(abuilder.withName("dictionary").withMinimum(1).withMaximum(1).create())
            .withDescription("The path to the term-dictionary format is ... ")
            .withShortName("d").create();

        Option dfsOpt = obuilder.withLongName("dfs").withRequired(false)
            .withArgument(abuilder.withName("dfs").withMinimum(1).withMaximum(1).create())
            .withDescription("MapR-FS namenode URI")
            .withShortName("dfs").create();

        Option numTopicsOpt = obuilder.withLongName("numTopics").withRequired(true)
            .withArgument(abuilder.withName("numTopics").withMinimum(1).withMaximum(1).create())
            .withDescription("Number of topics to learn")
            .withShortName("top").create();

        Option outputTopicFileOpt = obuilder.withLongName("topicOutputFile").withRequired(true)
            .withArgument(abuilder.withName("topicOutputFile").withMinimum(1).withMaximum(1).create())
            .withDescription("File to write out p(term | topic)")
            .withShortName("to").create();

        Option outputDocFileOpt = obuilder.withLongName("docOutputFile").withRequired(true)
            .withArgument(abuilder.withName("docOutputFile").withMinimum(1).withMaximum(1).create())
            .withDescription("File to write out p(topic | docid)")
            .withShortName("do").create();

        Option alphaOpt = obuilder.withLongName("alpha").withRequired(false)
            .withArgument(abuilder.withName("alpha").withMinimum(1).withMaximum(1).withDefault("0.1").create())
            .withDescription("Smoothing parameter for p(topic | document) prior")
            .withShortName("a").create();

        Option etaOpt = obuilder.withLongName("eta").withRequired(false)
            .withArgument(abuilder.withName("eta").withMinimum(1).withMaximum(1).withDefault("0.1").create())
            .withDescription("Smoothing parameter for p(term | topic)")
            .withShortName("e").create();

        Option maxIterOpt = obuilder.withLongName("maxIterations").withRequired(false)
            .withArgument(abuilder.withName("maxIterations").withMinimum(1).withMaximum(1).withDefault("10").create())
            .withDescription("Maximum number of training passes")
            .withShortName("m").create();

        Option modelCorpusFractionOption = obuilder.withLongName("modelCorpusFraction").withRequired(false)
            .withArgument(abuilder.withName("modelCorpusFraction").withMinimum(1).withMaximum(1).withDefault("0.0").create())
            .withShortName("mcf")
            .withDescription("For online updates, initial value of |model|/|corpus|").create();

        Option burnInOpt = obuilder.withLongName("burnInIterations").withRequired(false)
            .withArgument(abuilder.withName("burnInIterations").withMinimum(1).withMaximum(1).withDefault("5").create())
            .withDescription("Minimum number of iterations")
            .withShortName("b").create();

        Option convergenceOpt = obuilder.withLongName("convergence").withRequired(false)
            .withArgument(abuilder.withName("convergence").withMinimum(1).withMaximum(1).withDefault("0.0").create())
            .withDescription("Fractional rate of perplexity to consider convergence")
            .withShortName("c").create();

        // Registered so the flag is accepted on the command line; its value is not read below.
        Option reInferDocTopicsOpt = obuilder.withLongName("reInferDocTopics").withRequired(false)
            .withArgument(abuilder.withName("reInferDocTopics").withMinimum(1).withMaximum(1).withDefault("no").create())
            .withDescription("re-infer p(topic | doc) : [no | randstart | continue]")
            .withShortName("rdt").create();

        Option numTrainThreadsOpt = obuilder.withLongName("numTrainThreads").withRequired(false)
            .withArgument(abuilder.withName("numTrainThreads").withMinimum(1).withMaximum(1).withDefault("1").create())
            .withDescription("number of threads to train with")
            .withShortName("ntt").create();

        Option numUpdateThreadsOpt = obuilder.withLongName("numUpdateThreads").withRequired(false)
            .withArgument(abuilder.withName("numUpdateThreads").withMinimum(1).withMaximum(1).withDefault("1").create())
            .withDescription("number of threads to update the model with")
            .withShortName("nut").create();

        Option verboseOpt = obuilder.withLongName("verbose").withRequired(false)
            .withArgument(abuilder.withName("verbose").withMinimum(1).withMaximum(1).withDefault("false").create())
            .withDescription("print verbose information, like top-terms in each topic, during iteration")
            .withShortName("v").create();

        Group group = gbuilder.withName("Options")
            .withOption(inputDirOpt)
            .withOption(numTopicsOpt)
            .withOption(alphaOpt)
            .withOption(etaOpt)
            .withOption(maxIterOpt)
            .withOption(burnInOpt)
            .withOption(convergenceOpt)
            .withOption(dictOpt)
            .withOption(reInferDocTopicsOpt)
            .withOption(outputDocFileOpt)
            .withOption(outputTopicFileOpt)
            .withOption(dfsOpt)
            .withOption(numTrainThreadsOpt)
            .withOption(numUpdateThreadsOpt)
            .withOption(modelCorpusFractionOption)
            .withOption(verboseOpt)
            .create();

        try {
            Parser parser = new Parser();
            parser.setGroup(group);
            parser.setHelpOption(helpOpt);
            CommandLine cmdLine = parser.parse(args);
            if (cmdLine.hasOption(helpOpt)) {
                CommandLineUtil.printHelp(group);
                return -1;
            }

            String inputDirString = (String) cmdLine.getValue(inputDirOpt);
            String dictDirString = cmdLine.hasOption(dictOpt) ? (String) cmdLine.getValue(dictOpt) : null;
            int numTopics = Integer.parseInt((String) cmdLine.getValue(numTopicsOpt));
            double alpha = Double.parseDouble((String) cmdLine.getValue(alphaOpt));
            double eta = Double.parseDouble((String) cmdLine.getValue(etaOpt));
            int maxIterations = Integer.parseInt((String) cmdLine.getValue(maxIterOpt));
            int burnInIterations = Integer.parseInt((String) cmdLine.getValue(burnInOpt));
            double minFractionalErrorChange = Double.parseDouble((String) cmdLine.getValue(convergenceOpt));
            int numTrainThreads = Integer.parseInt((String) cmdLine.getValue(numTrainThreadsOpt));
            int numUpdateThreads = Integer.parseInt((String) cmdLine.getValue(numUpdateThreadsOpt));
            String topicOutFile = (String) cmdLine.getValue(outputTopicFileOpt);
            String docOutFile = (String) cmdLine.getValue(outputDocFileOpt);
            boolean verbose = Boolean.parseBoolean((String) cmdLine.getValue(verboseOpt));
            double modelCorpusFraction = Double.parseDouble((String) cmdLine.getValue(modelCorpusFractionOption));

            long start = System.nanoTime();

            if (conf.get("fs.default.name") == null) {
                // --dfs is optional and has no default, so only override when it was actually given.
                String dfsNameNode = (String) cmdLine.getValue(dfsOpt);
                if (dfsNameNode != null) {
                    conf.set("fs.default.name", dfsNameNode);
                }
            }

            String[] terms = loadDictionary(dictDirString, conf);
            logTime("dictionary loading", System.nanoTime() - start);

            start = System.nanoTime();
            Matrix corpus = loadVectors(inputDirString, conf);
            logTime("vector seqfile corpus loading", System.nanoTime() - start);

            start = System.nanoTime();
            InMemoryCollapsedVariationalBayes0 cvb0 = new InMemoryCollapsedVariationalBayes0(corpus, terms,
                numTopics, alpha, eta, numTrainThreads, numUpdateThreads, modelCorpusFraction);
            logTime("cvb0 init", System.nanoTime() - start);

            start = System.nanoTime();
            cvb0.setVerbose(verbose);
            cvb0.iterateUntilConvergence(minFractionalErrorChange, maxIterations, burnInIterations);
            logTime("total training time", System.nanoTime() - start);

            start = System.nanoTime();
            cvb0.writeModel(new Path(topicOutFile));
            DistributedRowMatrixWriter.write(new Path(docOutFile), conf, cvb0.docTopicCounts);
            logTime("printTopics", System.nanoTime() - start);
        } catch (OptionException e) {
            log.error("Error while parsing options", e);
            CommandLineUtil.printHelp(group);
            // Invalid arguments are a failure; do not report success to the caller/shell.
            return -1;
        }
        return 0;
    }

    /**
     * Reads a term dictionary sequence file whose keys are terms and whose values are term ids.
     *
     * @param dictionaryPath path of the dictionary file, or {@code null} if none was supplied
     * @param conf           Hadoop configuration used to open the file
     * @return an array mapping term id to term (ids absent from the file are left {@code null}),
     *         or {@code null} when {@code dictionaryPath} is {@code null}
     */
    private static String[] loadDictionary(String dictionaryPath, Configuration conf) {
        if (dictionaryPath == null) {
            return null;
        }
        Path dictionaryFile = new Path(dictionaryPath);
        ArrayList<Pair<Integer, String>> termList = new ArrayList<Pair<Integer, String>>();
        int maxTermId = 0;
        for (Pair<Writable, IntWritable> record
                : new SequenceFileIterable<Writable, IntWritable>(dictionaryFile, true, conf)) {
            int termId = record.getSecond().get();
            termList.add(new Pair<Integer, String>(termId, record.getFirst().toString()));
            maxTermId = Math.max(maxTermId, termId);
        }
        String[] terms = new String[maxTermId + 1];
        for (Pair<Integer, String> pair : termList) {
            terms[pair.getFirst()] = pair.getSecond();
        }
        return terms;
    }

    @Override
    public Configuration getConf() {
        return super.getConf();
    }

    /**
     * Loads every (row id, vector) record under the given path into one in-memory matrix.
     *
     * <p>If the path is a file it is read directly; otherwise every entry accepted by
     * {@code PathFilters.logsCRCFilter()} is read. The row count is the largest row id plus one and
     * the column count is taken from the first vector encountered. {@link NamedVector}s are unwrapped.
     *
     * @param vectorPathString file or directory containing the vector sequence files
     * @param conf             Hadoop configuration used to resolve the file system
     * @return the corpus as a {@link SparseRowMatrix}
     * @throws IOException if the file system cannot be accessed or no vectors are found
     */
    private static Matrix loadVectors(String vectorPathString, Configuration conf) throws IOException {
        Path vectorPath = new Path(vectorPathString);
        FileSystem fs = vectorPath.getFileSystem(conf);
        ArrayList<Path> subPaths = new ArrayList<Path>();
        if (fs.isFile(vectorPath)) {
            subPaths.add(vectorPath);
        } else {
            for (FileStatus fileStatus : fs.listStatus(vectorPath, PathFilters.logsCRCFilter())) {
                subPaths.add(fileStatus.getPath());
            }
        }

        ArrayList<Pair<Integer, Vector>> rowList = new ArrayList<Pair<Integer, Vector>>();
        int maxRowId = -1;
        int numCols = -1;
        boolean sequentialAccess = false;
        for (Path subPath : subPaths) {
            for (Pair<IntWritable, VectorWritable> record
                    : new SequenceFileIterable<IntWritable, VectorWritable>(subPath, true, conf)) {
                int id = record.getFirst().get();
                Vector vector = record.getSecond().get();
                if (vector instanceof NamedVector) {
                    vector = ((NamedVector) vector).getDelegate();
                }
                if (numCols < 0) {
                    // The first vector decides the matrix width and access pattern.
                    numCols = vector.size();
                    sequentialAccess = vector.isSequentialAccess();
                }
                rowList.add(Pair.of(id, vector));
                maxRowId = Math.max(maxRowId, id);
            }
        }

        if (rowList.isEmpty()) {
            // Without this guard the array allocation below would fail with an opaque
            // NegativeArraySizeException.
            throw new IOException("No vectors found under " + vectorPath);
        }

        int numRows = maxRowId + 1;
        Vector[] rowVectors = new Vector[numRows];
        for (Pair<Integer, Vector> pair : rowList) {
            rowVectors[pair.getFirst()] = pair.getSecond();
        }
        return new SparseRowMatrix(numRows, numCols, rowVectors, true, !sequentialAccess);
    }

    /**
     * {@link Tool} entry point; delegates to {@link #main2(String[], Configuration)} with this
     * job's configuration.
     */
    @Override
    public int run(String[] strings) throws Exception {
        return main2(strings, getConf());
    }

    /** Launches the job through {@link ToolRunner}. */
    public static void main(String[] args) throws Exception {
        ToolRunner.run(new InMemoryCollapsedVariationalBayes0(), args);
    }
}

