package org.apache.mahout.clustering.classify;

import com.google.common.collect.ForwardingIterator;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.clustering.Cluster;
import org.apache.mahout.clustering.iterator.ClusterWritable;
import org.apache.mahout.clustering.iterator.DistanceMeasureCluster;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterator;
import org.apache.mahout.math.NamedVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;

/* loaded from: input_file:org/apache/mahout/clustering/classify/ClusterClassificationDriver.class */
/**
 * Command-line driver that classifies input vectors against a set of previously
 * computed clusters and writes, for each accepted vector, the id of the chosen
 * cluster together with a {@link WeightedPropertyVectorWritable}.
 *
 * <p>The work is done either in-process ("sequential") or as a map-only Hadoop
 * job, depending on the method option.</p>
 */
public final class ClusterClassificationDriver extends AbstractJob {

    /**
     * Parses the command line and runs the classification.
     *
     * @param strArr command-line arguments
     * @return {@code 0} on success, {@code -1} if the arguments could not be parsed
     * @throws Exception if option parsing or the classification itself fails
     */
    @Override
    public int run(String[] strArr) throws Exception {
        addInputOption();
        addOutputOption();
        addOption(DefaultOptionCreator.methodOption().create());
        addOption(DefaultOptionCreator.clustersInOption().withDescription("The input centroids, as Vectors.  Must be a SequenceFile of Writable, Cluster/Canopy.").create());
        if (parseArguments(strArr) == null) {
            return -1;
        }
        Path input = getInputPath();
        Path output = getOutputPath();
        if (getConf() == null) {
            setConf(new Configuration());
        }
        Path clustersIn = new Path(getOption(DefaultOptionCreator.CLUSTERS_IN_OPTION));
        // Literal-first comparison so a missing method option cannot cause an NPE here.
        boolean sequential = "sequential".equalsIgnoreCase(getOption(DefaultOptionCreator.METHOD_OPTION));
        double threshold = 0.0d;
        // NOTE(review): no option with this name is registered in this method, so unless
        // a parent class adds it this branch may never be taken - confirm.
        if (hasOption(DefaultOptionCreator.OUTLIER_THRESHOLD)) {
            threshold = Double.parseDouble(getOption(DefaultOptionCreator.OUTLIER_THRESHOLD));
        }
        run(getConf(), input, clustersIn, output, threshold, true, sequential);
        return 0;
    }

    private ClusterClassificationDriver() {
    }

    /**
     * Command-line entry point; hands the arguments to {@link ToolRunner}.
     *
     * @param strArr command-line arguments
     * @throws Exception if the tool fails
     */
    public static void main(String[] strArr) throws Exception {
        ToolRunner.run(new Configuration(), new ClusterClassificationDriver(), strArr);
    }

    /**
     * Boxed-threshold variant of
     * {@link #run(Configuration, Path, Path, Path, double, boolean, boolean)}.
     *
     * @param configuration Hadoop configuration to use
     * @param path          directory containing the input vectors
     * @param path2         directory containing the clustering output to classify against
     * @param path3         directory the classification result is written to
     * @param d             outlier threshold; must not be {@code null}
     * @param z             if {@code true} only the most likely cluster is emitted for a vector
     * @param z2            if {@code true} run in-process, otherwise run as a MapReduce job
     * @throws NullPointerException if {@code d} is {@code null}
     */
    public static void run(Configuration configuration, Path path, Path path2, Path path3, Double d, boolean z, boolean z2) throws IOException, InterruptedException, ClassNotFoundException {
        // Unboxing selects the primitive overload, so both entry points share one code path.
        run(configuration, path, path2, path3, d.doubleValue(), z, z2);
    }

    /**
     * Classifies the vectors under {@code path} against the clusters under {@code path2}.
     *
     * @param configuration Hadoop configuration to use
     * @param path          directory containing the input vectors
     * @param path2         directory containing the clustering output to classify against
     * @param path3         directory the classification result is written to
     * @param d             outlier threshold; vectors whose best score is below it are skipped
     * @param z             if {@code true} only the most likely cluster is emitted for a vector
     * @param z2            if {@code true} run in-process, otherwise run as a MapReduce job
     * @throws IOException            on read/write failure
     * @throws InterruptedException   if the MapReduce job does not complete successfully
     * @throws ClassNotFoundException propagated from job submission
     */
    public static void run(Configuration configuration, Path path, Path path2, Path path3, double d, boolean z, boolean z2) throws IOException, InterruptedException, ClassNotFoundException {
        if (z2) {
            classifyClusterSeq(configuration, path, path2, path3, d, z);
        } else {
            classifyClusterMR(configuration, path, path2, path3, d, z);
        }
    }

    /** In-process classification: loads the cluster models, builds a classifier and writes the result. */
    private static void classifyClusterSeq(Configuration conf, Path input, Path clustersIn, Path output, double threshold, boolean emitMostLikely) throws IOException {
        List<Cluster> clusterModels = populateClusterModels(clustersIn, conf);
        ClusterClassifier classifier = new ClusterClassifier(clusterModels, ClusterClassifier.readPolicy(finalClustersPath(conf, clustersIn)));
        selectCluster(input, clusterModels, classifier, output, threshold, emitMostLikely);
    }

    /**
     * Reads every cluster stored in the final clusters directory below {@code clustersIn}
     * and configures it with {@code conf}.
     */
    private static List<Cluster> populateClusterModels(Path clustersIn, Configuration conf) throws IOException {
        List<Cluster> clusterModels = Lists.newArrayList();
        SequenceFileDirValueIterator<ClusterWritable> clusters = new SequenceFileDirValueIterator<ClusterWritable>(finalClustersPath(conf, clustersIn), PathType.LIST, PathFilters.partFilter(), null, false, conf);
        try {
            while (clusters.hasNext()) {
                Cluster cluster = clusters.next().getValue();
                cluster.configure(conf);
                clusterModels.add(cluster);
            }
        } finally {
            // Release the underlying readers even when iteration is aborted by an exception.
            clusters.close();
        }
        return clusterModels;
    }

    /**
     * Returns the first entry below {@code clustersIn} accepted by
     * {@link PathFilters#finalPartFilter()}.
     *
     * @throws IOException if no such entry exists
     */
    private static Path finalClustersPath(Configuration conf, Path clustersIn) throws IOException {
        FileStatus[] finalDirs = clustersIn.getFileSystem(conf).listStatus(clustersIn, PathFilters.finalPartFilter());
        if (finalDirs == null || finalDirs.length == 0) {
            throw new IOException("No final clusters directory found under " + clustersIn);
        }
        return finalDirs[0].getPath();
    }

    /**
     * Classifies every input vector and appends the accepted ones to
     * {@code <output>/part-m-0}.
     */
    private static void selectCluster(Path input, List<Cluster> clusterModels, ClusterClassifier classifier, Path output, double threshold, boolean emitMostLikely) throws IOException {
        // NOTE(review): a fresh Configuration is used here rather than the caller's - confirm
        // whether that is intentional.
        Configuration conf = new Configuration();
        SequenceFile.Writer writer = new SequenceFile.Writer(input.getFileSystem(conf), conf, new Path(output, "part-m-0"), IntWritable.class, WeightedPropertyVectorWritable.class);
        try {
            for (Pair<Writable, VectorWritable> record : new SequenceFileDirIterable<Writable, VectorWritable>(input, PathType.LIST, PathFilters.logsCRCFilter(), conf)) {
                Writable key = record.getFirst();
                Class<?> keyClass = key.getClass();
                Vector vector = record.getSecond().get();
                // Attach the record key as the vector name so the output identifies the point.
                // NOTE(review): this compares the key's class with NamedVector; it looks like
                // the check was meant for the vector - confirm before changing.
                if (!keyClass.equals(NamedVector.class)) {
                    if (keyClass.equals(Text.class)) {
                        vector = new NamedVector(vector, key.toString());
                    } else if (keyClass.equals(IntWritable.class)) {
                        vector = new NamedVector(vector, Integer.toString(((IntWritable) key).get()));
                    }
                }
                Vector scores = classifier.classify(vector);
                if (shouldClassify(scores, threshold)) {
                    classifyAndWrite(clusterModels, threshold, emitMostLikely, writer, new VectorWritable(vector), scores);
                }
            }
        } finally {
            // Always close the writer so a failure mid-run does not leak the output stream.
            writer.close();
        }
    }

    /**
     * Writes either the single best-scoring cluster for the vector or every cluster whose
     * score reaches the threshold, depending on {@code emitMostLikely}.
     */
    private static void classifyAndWrite(List<Cluster> clusterModels, double threshold, boolean emitMostLikely, SequenceFile.Writer writer, VectorWritable vectorWritable, Vector scores) throws IOException {
        if (emitMostLikely) {
            Map<Text, Text> properties = Maps.newHashMap();
            WeightedPropertyVectorWritable weighted = new WeightedPropertyVectorWritable(scores.maxValue(), vectorWritable.get(), properties);
            write(clusterModels, writer, weighted, scores.maxValueIndex());
        } else {
            writeAllAboveThreshold(clusterModels, threshold, writer, vectorWritable, scores);
        }
    }

    /** Writes one record for each non-zero score that is at least {@code threshold}. */
    private static void writeAllAboveThreshold(List<Cluster> clusterModels, double threshold, SequenceFile.Writer writer, VectorWritable vectorWritable, Vector scores) throws IOException {
        // One map is shared by all records of this vector; write() overwrites its single
        // "distance" entry before each append.
        Map<Text, Text> properties = Maps.newHashMap();
        for (Vector.Element score : scores.nonZeroes()) {
            if (score.get() >= threshold) {
                WeightedPropertyVectorWritable weighted = new WeightedPropertyVectorWritable(score.get(), vectorWritable.get(), properties);
                write(clusterModels, writer, weighted, score.index());
            }
        }
    }

    /**
     * Records the distance between the vector and the centre of the cluster at
     * {@code clusterIndex} under the "distance" property, then appends the record keyed
     * by the cluster id. The cluster is cast to {@link DistanceMeasureCluster}.
     */
    private static void write(List<Cluster> clusterModels, SequenceFile.Writer writer, WeightedPropertyVectorWritable weighted, int clusterIndex) throws IOException {
        Cluster cluster = clusterModels.get(clusterIndex);
        double distance = ((DistanceMeasureCluster) cluster).getMeasure().distance(cluster.getCenter(), weighted.getVector());
        weighted.getProperties().put(new Text("distance"), new Text(Double.toString(distance)));
        writer.append(new IntWritable(cluster.getId()), weighted);
    }

    /** Returns {@code true} when the largest score reaches {@code threshold}. */
    private static boolean shouldClassify(Vector scores, double threshold) {
        return scores.maxValue() >= threshold;
    }

    /**
     * Runs the classification as a map-only Hadoop job, passing the threshold, the
     * emit-most-likely flag and the clusters location through the job configuration.
     *
     * @throws InterruptedException if the job does not complete successfully
     */
    private static void classifyClusterMR(Configuration conf, Path input, Path clustersIn, Path output, double threshold, boolean emitMostLikely) throws IOException, InterruptedException, ClassNotFoundException {
        conf.setFloat(ClusterClassificationConfigKeys.OUTLIER_REMOVAL_THRESHOLD, (float) threshold);
        conf.setBoolean(ClusterClassificationConfigKeys.EMIT_MOST_LIKELY, emitMostLikely);
        conf.set(ClusterClassificationConfigKeys.CLUSTERS_IN, clustersIn.toUri().toString());

        Job job = new Job(conf, "Cluster Classification Driver running over input: " + input);
        job.setJarByClass(ClusterClassificationDriver.class);

        job.setInputFormatClass(SequenceFileInputFormat.class);
        job.setOutputFormatClass(SequenceFileOutputFormat.class);

        job.setMapperClass(ClusterClassificationMapper.class);
        job.setNumReduceTasks(0);

        job.setOutputKeyClass(IntWritable.class);
        job.setOutputValueClass(WeightedPropertyVectorWritable.class);

        FileInputFormat.addInputPath(job, input);
        FileOutputFormat.setOutputPath(job, output);
        if (!job.waitForCompletion(true)) {
            throw new InterruptedException("Cluster Classification Driver Job failed processing " + input);
        }
    }
}
