package org.apache.mahout.clustering.dirichlet;

import com.google.common.collect.Lists;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.clustering.Cluster;
import org.apache.mahout.clustering.Model;
import org.apache.mahout.clustering.ModelDistribution;
import org.apache.mahout.clustering.classify.ClusterClassificationDriver;
import org.apache.mahout.clustering.classify.ClusterClassifier;
import org.apache.mahout.clustering.dirichlet.models.DistributionDescription;
import org.apache.mahout.clustering.dirichlet.models.GaussianClusterDistribution;
import org.apache.mahout.clustering.iterator.ClusterIterator;
import org.apache.mahout.clustering.iterator.DirichletClusteringPolicy;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterator;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.VectorWritable;

@Deprecated
/* loaded from: input_file:org/apache/mahout/clustering/dirichlet/DirichletDriver.class */
public class DirichletDriver extends AbstractJob {
    public static final String STATE_IN_KEY = "org.apache.mahout.clustering.dirichlet.stateIn";
    public static final String MODEL_DISTRIBUTION_KEY = "org.apache.mahout.clustering.dirichlet.modelFactory";
    public static final String NUM_CLUSTERS_KEY = "org.apache.mahout.clustering.dirichlet.numClusters";
    public static final String ALPHA_0_KEY = "org.apache.mahout.clustering.dirichlet.alpha_0";
    public static final String EMIT_MOST_LIKELY_KEY = "org.apache.mahout.clustering.dirichlet.emitMostLikely";
    public static final String THRESHOLD_KEY = "org.apache.mahout.clustering.dirichlet.threshold";
    public static final String MODEL_PROTOTYPE_CLASS_OPTION = "modelPrototype";
    public static final String MODEL_DISTRIBUTION_CLASS_OPTION = "modelDist";
    public static final String ALPHA_OPTION = "alpha";

    public static void main(String[] strArr) throws Exception {
        ToolRunner.run(new Configuration(), new DirichletDriver(), strArr);
    }

    public int run(String[] strArr) throws Exception {
        addInputOption();
        addOutputOption();
        addOption(DefaultOptionCreator.maxIterationsOption().create());
        addOption(DefaultOptionCreator.numClustersOption().withRequired(true).create());
        addOption(DefaultOptionCreator.overwriteOption().create());
        addOption(DefaultOptionCreator.clusteringOption().create());
        addOption(ALPHA_OPTION, "a0", "The alpha0 value for the DirichletDistribution. Defaults to 1.0", "1.0");
        addOption(MODEL_DISTRIBUTION_CLASS_OPTION, "md", "The ModelDistribution class name. Defaults to GaussianClusterDistribution", GaussianClusterDistribution.class.getName());
        addOption(MODEL_PROTOTYPE_CLASS_OPTION, "mp", "The ModelDistribution prototype Vector class name. Defaults to RandomAccessSparseVector", RandomAccessSparseVector.class.getName());
        addOption(DefaultOptionCreator.distanceMeasureOption().withRequired(false).create());
        addOption(DefaultOptionCreator.emitMostLikelyOption().create());
        addOption(DefaultOptionCreator.thresholdOption().create());
        addOption(DefaultOptionCreator.methodOption().create());
        if (parseArguments(strArr) == null) {
            return -1;
        }
        Path inputPath = getInputPath();
        Path outputPath = getOutputPath();
        if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
            HadoopUtil.delete(getConf(), outputPath);
        }
        String option = getOption(MODEL_DISTRIBUTION_CLASS_OPTION);
        String option2 = getOption(MODEL_PROTOTYPE_CLASS_OPTION);
        String option3 = getOption(DefaultOptionCreator.DISTANCE_MEASURE_OPTION);
        int parseInt = Integer.parseInt(getOption(DefaultOptionCreator.NUM_CLUSTERS_OPTION));
        int parseInt2 = Integer.parseInt(getOption(DefaultOptionCreator.MAX_ITERATIONS_OPTION));
        boolean parseBoolean = Boolean.parseBoolean(getOption(DefaultOptionCreator.EMIT_MOST_LIKELY_OPTION));
        double parseDouble = Double.parseDouble(getOption(DefaultOptionCreator.THRESHOLD_OPTION));
        run(getConf(), inputPath, outputPath, new DistributionDescription(option, option2, option3, readPrototypeSize(inputPath)), parseInt, parseInt2, Double.parseDouble(getOption(ALPHA_OPTION)), hasOption(DefaultOptionCreator.CLUSTERING_OPTION), parseBoolean, parseDouble, getOption("method").equalsIgnoreCase("sequential"));
        return 0;
    }

    public static void run(Configuration configuration, Path path, Path path2, DistributionDescription distributionDescription, int i, int i2, double d, boolean z, boolean z2, double d2, boolean z3) throws IOException, ClassNotFoundException, InterruptedException {
        Path buildClusters = buildClusters(configuration, path, path2, distributionDescription, i, i2, d, z3);
        if (z) {
            clusterData(configuration, path, buildClusters, path2, d, i, z2, d2, z3);
        }
    }

    public static int readPrototypeSize(Path path) throws IOException {
        Configuration configuration = new Configuration();
        FileStatus[] listStatus = FileSystem.get(path.toUri(), configuration).listStatus(path, PathFilters.logsCRCFilter());
        int i = 0;
        if (listStatus.length > 0) {
            Iterator it = new SequenceFileValueIterable(listStatus[0].getPath(), true, configuration).iterator();
            while (it.hasNext()) {
                i = ((VectorWritable) it.next()).get().size();
            }
        }
        return i;
    }

    public static Path buildClusters(Configuration configuration, Path path, Path path2, DistributionDescription distributionDescription, int i, int i2, double d, boolean z) throws IOException, ClassNotFoundException, InterruptedException {
        Path path3 = new Path(path2, Cluster.INITIAL_CLUSTERS_DIR);
        ModelDistribution<VectorWritable> createModelDistribution = distributionDescription.createModelDistribution(configuration);
        ArrayList newArrayList = Lists.newArrayList();
        for (Model<VectorWritable> model : createModelDistribution.sampleFromPrior(i)) {
            newArrayList.add((Cluster) model);
        }
        new ClusterClassifier(newArrayList, new DirichletClusteringPolicy(i, d)).writeToSeqFiles(path3);
        if (z) {
            ClusterIterator.iterateSeq(configuration, path, path3, path2, i2);
        } else {
            ClusterIterator.iterateMR(configuration, path, path3, path2, i2);
        }
        return path2;
    }

    public static void clusterData(Configuration configuration, Path path, Path path2, Path path3, double d, int i, boolean z, double d2, boolean z2) throws IOException, InterruptedException, ClassNotFoundException {
        ClusterClassifier.writePolicy(new DirichletClusteringPolicy(i, d), path2);
        ClusterClassificationDriver.run(configuration, path, path3, new Path(path3, "clusteredPoints"), d2, z, z2);
    }
}
