package org.apache.mahout.clustering.streaming.mapreduce;

import com.google.common.base.Function;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mrunit.mapreduce.MapDriver;
import org.apache.hadoop.mrunit.mapreduce.MapReduceDriver;
import org.apache.hadoop.mrunit.mapreduce.ReduceDriver;
import org.apache.mahout.clustering.ClusteringUtils;
import org.apache.mahout.clustering.streaming.cluster.DataUtils;
import org.apache.mahout.clustering.streaming.cluster.StreamingKMeans;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
import org.apache.mahout.math.Centroid;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.neighborhood.BruteSearch;
import org.apache.mahout.math.neighborhood.FastProjectionSearch;
import org.apache.mahout.math.neighborhood.LocalitySensitiveHashSearch;
import org.apache.mahout.math.neighborhood.ProjectionSearch;
import org.apache.mahout.math.random.WeightedThing;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

/**
 * MRUnit tests for the StreamingKMeans mapper, reducer and driver.
 *
 * <p>The suite is parameterized over (searcher class name, distance measure class name) pairs
 * supplied by {@link #generateData()}; every test runs once per pair against the same shared
 * synthetic data set.
 */
@RunWith(Parameterized.class)
public class StreamingKMeansTestMR {
    private static final int NUM_PROJECTIONS = 3;
    private static final int SEARCH_SIZE = 5;
    private static final int MAX_NUM_ITERATIONS = 10;
    private static final double DISTANCE_CUTOFF = 1.0E-6d;
    private static final int NUM_DIMENSIONS = 8;
    private static final int NUM_DATA_POINTS = 32768;

    /** Number of final clusters requested from the reducer: 2^NUM_DIMENSIONS = 256. */
    private static final int NUM_CLUSTERS = 1 << NUM_DIMENSIONS;

    /**
     * Cluster-count estimate handed to the mapper and to the local clusterers.
     * The logarithm is truncated to an int BEFORE the multiplication (256 * 10 = 2560);
     * this ordering is kept on purpose so the tested configuration does not change.
     */
    private static final int ESTIMATED_NUM_MAP_CLUSTERS = NUM_CLUSTERS * (int) Math.log(NUM_DATA_POINTS);

    /**
     * Shared synthetic input. The first list holds the data points that get clustered; the
     * second list holds reference points that {@link #testHypercubeMapper()} expects to find a
     * nearby result centroid for (presumably the means of the sampled distributions -
     * TODO confirm against DataUtils.sampleMultiNormalHypercube).
     */
    private static final Pair<List<Centroid>, List<Centroid>> syntheticData =
            DataUtils.sampleMultiNormalHypercube(NUM_DIMENSIONS, NUM_DATA_POINTS, 1.0E-4d);

    private final String searcherClassName;
    private final String distanceMeasureClassName;

    /**
     * Invoked by the Parameterized runner with one row of {@link #generateData()}.
     *
     * @param searcherClassName fully qualified name of the searcher class under test
     * @param distanceMeasureClassName fully qualified name of the distance measure class
     */
    public StreamingKMeansTestMR(String searcherClassName, String distanceMeasureClassName) {
        this.searcherClassName = searcherClassName;
        this.distanceMeasureClassName = distanceMeasureClassName;
    }

    /**
     * Writes the StreamingKMeans job options used by every test into {@code configuration}.
     *
     * @param configuration the configuration to populate (modified in place)
     */
    private void configure(Configuration configuration) {
        configuration.set("distanceMeasure", distanceMeasureClassName);
        configuration.setInt("searchSize", SEARCH_SIZE);
        configuration.setInt("numProjections", NUM_PROJECTIONS);
        configuration.set("searcherClass", searcherClassName);
        configuration.setInt("numClusters", NUM_CLUSTERS);
        configuration.setInt("estimatedNumMapClusters", ESTIMATED_NUM_MAP_CLUSTERS);
        configuration.setFloat("estimatedDistanceCutoff", (float) DISTANCE_CUTOFF);
        configuration.setInt("maxNumIterations", MAX_NUM_ITERATIONS);
        configuration.setBoolean("reduceStreamingKMeans", true);
    }

    /**
     * Supplies the constructor arguments for each parameterized run.
     *
     * @return one {searcher class name, distance measure class name} row per searcher
     */
    @Parameterized.Parameters
    public static List<Object[]> generateData() {
        String distanceMeasure = SquaredEuclideanDistanceMeasure.class.getName();
        return Arrays.asList(
                new Object[]{ProjectionSearch.class.getName(), distanceMeasure},
                new Object[]{FastProjectionSearch.class.getName(), distanceMeasure},
                new Object[]{LocalitySensitiveHashSearch.class.getName(), distanceMeasure});
    }

    /**
     * Runs the mapper over all data points and checks that every reference point has a result
     * centroid whose search weight (under squared Euclidean distance) is below 0.5.
     */
    @Test
    public void testHypercubeMapper() throws IOException {
        MapDriver<Writable, VectorWritable, IntWritable, CentroidWritable> mapDriver =
                MapDriver.newMapDriver(new StreamingKMeansMapper());
        configure(mapDriver.getConfiguration());
        System.out.printf("%s mapper test\n", mapDriver.getConfiguration().get("searcherClass"));
        for (Centroid datapoint : syntheticData.getFirst()) {
            mapDriver.addInput(new IntWritable(0), new VectorWritable(datapoint));
        }

        List<org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable>> results = mapDriver.run();
        BruteSearch resultSearcher = new BruteSearch(new SquaredEuclideanDistanceMeasure());
        for (org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable> result : results) {
            resultSearcher.add(result.getSecond().getCentroid());
        }
        System.out.printf("Clustered the data into %d clusters\n", results.size());

        for (Centroid reference : syntheticData.getSecond()) {
            WeightedThing<?> closest = resultSearcher.search(reference, 1).get(0);
            Assert.assertTrue("Weight " + closest.getWeight() + " not less than 0.5", closest.getWeight() < 0.5d);
        }
    }

    /**
     * Compares the total clustering cost of three runs over the same data: the mapper, a local
     * StreamingKMeans fed the whole data set at once, and a local StreamingKMeans fed one point
     * at a time. Both cost ratios against the batch run must lie within 0.8 of 1.
     */
    @Test
    public void testMapperVsLocal() throws IOException {
        MapDriver<Writable, VectorWritable, IntWritable, CentroidWritable> mapDriver =
                MapDriver.newMapDriver(new StreamingKMeansMapper());
        Configuration configuration = mapDriver.getConfiguration();
        configure(configuration);
        System.out.printf("%s mapper vs local test\n", mapDriver.getConfiguration().get("searcherClass"));

        List<Centroid> datapoints = syntheticData.getFirst();
        for (Centroid datapoint : datapoints) {
            mapDriver.addInput(new IntWritable(0), new VectorWritable(datapoint));
        }

        // Clustering done by the mapper.
        List<Centroid> mapperCentroids = Lists.newArrayList();
        for (org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable> result : mapDriver.run()) {
            mapperCentroids.add(result.getSecond().getCentroid());
        }

        // Local clustering of the whole data set in one call, sized from the mapper's configuration.
        StreamingKMeans batchClusterer = new StreamingKMeans(
                StreamingKMeansUtilsMR.searcherFromConfiguration(configuration),
                mapDriver.getConfiguration().getInt("estimatedNumMapClusters", -1),
                DISTANCE_CUTOFF);
        batchClusterer.cluster(datapoints);
        List<Centroid> batchCentroids = collectCentroids(batchClusterer.iterator());

        // Local clustering that adds the points one at a time.
        StreamingKMeans perPointClusterer = new StreamingKMeans(
                StreamingKMeansUtilsMR.searcherFromConfiguration(configuration),
                ESTIMATED_NUM_MAP_CLUSTERS,
                DISTANCE_CUTOFF);
        for (Centroid datapoint : datapoints) {
            perPointClusterer.cluster(datapoint);
        }
        List<Centroid> perPointCentroids = collectCentroids(perPointClusterer.iterator());

        double mapperCost = ClusteringUtils.totalClusterCost(datapoints, mapperCentroids);
        double batchCost = ClusteringUtils.totalClusterCost(datapoints, batchCentroids);
        double perPointCost = ClusteringUtils.totalClusterCost(datapoints, perPointCentroids);
        System.out.printf("[Total cost] Mapper %f [%d] Local %f [%d] Perpoint local %f [%d];[ratio m-vs-l %f] [ratio pp-vs-l %f]\n",
                mapperCost, mapperCentroids.size(),
                batchCost, batchCentroids.size(),
                perPointCost, perPointCentroids.size(),
                mapperCost / batchCost, perPointCost / batchCost);

        Assert.assertEquals("Mapper StreamingKMeans / Batch local StreamingKMeans total cost ratio too far from 1",
                1.0d, mapperCost / batchCost, 0.8d);
        Assert.assertEquals("One by one local StreamingKMeans / Batch local StreamingKMeans total cost ratio too high",
                1.0d, perPointCost / batchCost, 0.8d);
    }

    /**
     * Clusters the data locally, feeds the resulting centroids to the reducer under key 0 and
     * validates the reducer output with {@link #testReducerResults(int, List)}. Also prints the
     * wall-clock time of the local clustering step.
     */
    @Test
    public void testHypercubeReducer() throws IOException {
        ReduceDriver<IntWritable, CentroidWritable, IntWritable, CentroidWritable> reduceDriver =
                ReduceDriver.newReduceDriver(new StreamingKMeansReducer());
        Configuration configuration = reduceDriver.getConfiguration();
        configure(configuration);
        System.out.printf("%s reducer test\n", configuration.get("searcherClass"));

        StreamingKMeans clusterer = new StreamingKMeans(
                StreamingKMeansUtilsMR.searcherFromConfiguration(configuration),
                ESTIMATED_NUM_MAP_CLUSTERS,
                DISTANCE_CUTOFF);
        long start = System.currentTimeMillis();
        clusterer.cluster(syntheticData.getFirst());
        long end = System.currentTimeMillis();
        System.out.printf("%f [s]\n", (end - start) / 1000.0d);

        List<CentroidWritable> reducerInputs = Lists.newArrayList();
        int postMapperTotalWeight = 0;
        for (Centroid centroid : collectCentroids(clusterer.iterator())) {
            reducerInputs.add(new CentroidWritable(centroid));
            // int += double narrows after every addition, exactly as the original accumulation did.
            postMapperTotalWeight += centroid.getWeight();
        }

        reduceDriver.addInput(new IntWritable(0), reducerInputs);
        testReducerResults(postMapperTotalWeight, reduceDriver.run());
    }

    /**
     * Runs mapper and reducer together through MRUnit and validates the final centroids; the
     * expected total weight is the number of input data points.
     */
    @Test
    public void testHypercubeMapReduce() throws IOException {
        MapReduceDriver<Writable, VectorWritable, IntWritable, CentroidWritable, IntWritable, CentroidWritable>
                mapReduceDriver =
                new MapReduceDriver<Writable, VectorWritable, IntWritable, CentroidWritable, IntWritable, CentroidWritable>(
                        new StreamingKMeansMapper(), new StreamingKMeansReducer());
        Configuration configuration = mapReduceDriver.getConfiguration();
        configure(configuration);
        System.out.printf("%s full test\n", configuration.get("searcherClass"));
        for (Centroid datapoint : syntheticData.getFirst()) {
            mapReduceDriver.addInput(new IntWritable(0), new VectorWritable(datapoint));
        }
        testReducerResults(syntheticData.getFirst().size(), mapReduceDriver.run());
    }

    /**
     * Runs StreamingKMeansDriver with the "method" option set to "sequential", reading the data
     * from a sequence file and validating the centroids written to the output path.
     *
     * <p>NOTE(review): "testInput" and "testOutput" are relative paths and this test does not
     * delete them afterwards.
     */
    @Test
    public void testHypercubeMapReduceRunSequentially() throws Exception {
        Configuration configuration = new Configuration();
        configure(configuration);
        configuration.set("method", "sequential");

        Path inputPath = new Path("testInput");
        Path outputPath = new Path("testOutput");
        StreamingKMeansUtilsMR.writeVectorsToSequenceFile(syntheticData.getFirst(), inputPath, configuration);

        StreamingKMeansDriver.run(configuration, inputPath, outputPath);

        // Convert Mahout pairs read from the output file into the MRUnit pairs the checker expects.
        Iterable<org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable>> results = Iterables.transform(
                new SequenceFileIterable<IntWritable, CentroidWritable>(outputPath, configuration),
                new Function<Pair<IntWritable, CentroidWritable>,
                        org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable>>() {
                    @Override
                    public org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable> apply(
                            Pair<IntWritable, CentroidWritable> input) {
                        return new org.apache.hadoop.mrunit.types.Pair<>(input.getFirst(), input.getSecond());
                    }
                });
        testReducerResults(syntheticData.getFirst().size(), Lists.newArrayList(results));
    }

    /**
     * Drains an iterator of clusters into a list, casting each element to {@link Centroid}.
     *
     * @param clusters iterator whose elements are all expected to be Centroid instances
     * @return the elements in iteration order
     */
    private static List<Centroid> collectCentroids(Iterator<?> clusters) {
        List<Centroid> centroids = Lists.newArrayList();
        while (clusters.hasNext()) {
            centroids.add((Centroid) clusters.next());
        }
        return centroids;
    }

    /**
     * Validates final centroids: keys must be 0, 1, 2, ... in order, there must be exactly
     * {@code NUM_CLUSTERS} of them, and their weights must add up to {@code totalWeight}.
     * Centroids whose weight differs from the perfectly balanced share are only reported on
     * standard output; they do not fail the test.
     *
     * @param totalWeight expected sum of all centroid weights
     * @param results (index, centroid) pairs in output order
     */
    private static void testReducerResults(int totalWeight,
            List<org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable>> results) {
        // Cast before dividing: int / int would truncate the balanced share to a whole number.
        double expectedWeight = (double) totalWeight / NUM_CLUSTERS;
        int numClusters = 0;
        int numUnbalancedClusters = 0;
        int totalReducerWeight = 0;
        for (org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable> result : results) {
            Centroid centroid = result.getSecond().getCentroid();
            if (centroid.getWeight() != expectedWeight) {
                System.out.printf("Unbalanced weight %f in centroid %d\n", centroid.getWeight(), centroid.getIndex());
                numUnbalancedClusters++;
            }
            Assert.assertEquals("Final centroid index is invalid", numClusters, result.getFirst().get());
            // int += double narrows after every addition, exactly as the original accumulation did.
            totalReducerWeight += centroid.getWeight();
            numClusters++;
        }
        System.out.printf("%d clusters are unbalanced\n", numUnbalancedClusters);
        Assert.assertEquals("Invalid total weight", totalWeight, totalReducerWeight);
        Assert.assertEquals("Invalid number of clusters", NUM_CLUSTERS, numClusters);
    }
}
