package org.apache.mahout.clustering.streaming.cluster;

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.apache.mahout.clustering.ClusteringUtils;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
import org.apache.mahout.math.Centroid;
import org.apache.mahout.math.neighborhood.BruteSearch;
import org.apache.mahout.math.neighborhood.FastProjectionSearch;
import org.apache.mahout.math.neighborhood.ProjectionSearch;
import org.apache.mahout.math.neighborhood.UpdatableSearcher;
import org.apache.mahout.math.random.WeightedThing;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

/**
 * Parameterized tests for {@link StreamingKMeans} on synthetic data produced by
 * {@code DataUtils.sampleMultiNormalHypercube(NUM_DIMENSIONS, NUM_DATA_POINTS)}.
 *
 * <p>Each parameter set pairs an {@link UpdatableSearcher} implementation with a flag that selects
 * whether the data is handed to the clusterer as a single {@link Iterable} ({@code true}) or one
 * {@link Centroid} at a time ({@code false}).
 *
 * <p>NOTE(review): JUnit's {@code Parameterized} runner reuses the same parameter objects for every
 * test method, so each test calls {@code searcher.clear()} before using the shared searcher.
 */
@RunWith(Parameterized.class)
public class StreamingKMeansTest {
    private static final int NUM_PROJECTIONS = 2;
    private static final int SEARCH_SIZE = 10;
    private static final int NUM_DIMENSIONS = 6;
    private static final int NUM_DATA_POINTS = 65536;

    /**
     * Number of hypercube corners (2^NUM_DIMENSIONS). The test assumes the true cluster centers in
     * {@code syntheticData.getSecond()} carry indices in {@code [0, NUM_CORNERS)}.
     */
    private static final int NUM_CORNERS = 1 << NUM_DIMENSIONS;

    private static final Pair<List<Centroid>, List<Centroid>> syntheticData =
            DataUtils.sampleMultiNormalHypercube(NUM_DIMENSIONS, NUM_DATA_POINTS);

    private final UpdatableSearcher searcher;
    private final boolean allAtOnce;

    /**
     * Creates one test instance for a parameter set from {@link #generateData()}.
     *
     * @param searcher the searcher the clusterer under test is built on
     * @param allAtOnce {@code true} to cluster the whole data set in one call, {@code false} to feed
     *     points individually
     */
    public StreamingKMeansTest(UpdatableSearcher searcher, boolean allAtOnce) {
        this.searcher = searcher;
        this.allAtOnce = allAtOnce;
    }

    /**
     * Supplies the parameter sets: {@link ProjectionSearch} and {@link FastProjectionSearch}, each with
     * squared-Euclidean distance, each run in both all-at-once and point-by-point mode.
     *
     * @return the parameter arrays consumed by the constructor
     */
    @Parameterized.Parameters
    public static List<Object[]> generateData() {
        return Arrays.asList(
                new Object[] {
                    new ProjectionSearch(new SquaredEuclideanDistanceMeasure(), NUM_PROJECTIONS, SEARCH_SIZE), true
                },
                new Object[] {
                    new FastProjectionSearch(new SquaredEuclideanDistanceMeasure(), NUM_PROJECTIONS, SEARCH_SIZE), true
                },
                new Object[] {
                    new ProjectionSearch(new SquaredEuclideanDistanceMeasure(), NUM_PROJECTIONS, SEARCH_SIZE), false
                },
                new Object[] {
                    new FastProjectionSearch(new SquaredEuclideanDistanceMeasure(), NUM_PROJECTIONS, SEARCH_SIZE), false
                });
    }

    /**
     * Returns the cluster-count argument passed to {@link StreamingKMeans}.
     *
     * <p>It is {@code floor(ln n) * NUM_CORNERS} ("k log n"), where n is the number of data points.
     * The {@code (int)} cast applies to the logarithm before the multiplication.
     */
    private static int estimateNumStreamingClusters() {
        return (int) Math.log(syntheticData.getFirst().size()) * NUM_CORNERS;
    }

    /** Estimates the initial distance cutoff from the data using this test's searcher's distance measure. */
    private double estimateCutoff() {
        return ClusteringUtils.estimateDistanceCutoff(syntheticData.getFirst(), searcher.getDistanceMeasure(), 100);
    }

    /** Reports the average final distance cutoff and cluster count over {@code numTests} runs. */
    @Test
    public void testAverageDistanceCutoff() {
        final int numTests = 1;
        // Printed only for comparison with the estimate; the clusterer is given the estimate.
        final double magicCutoff = 1.0e-6;
        double totalDistanceCutoff = 0;
        double totalNumClusters = 0;
        System.out.printf("Distance cutoff for %s\n", searcher.getClass().getName());
        for (int i = 0; i < numTests; i++) {
            searcher.clear();
            int numStreamingClusters = estimateNumStreamingClusters();
            double estimatedCutoff = estimateCutoff();
            System.out.printf("[%d] Generated synthetic data [magic] %f [estimate] %f\n", i, magicCutoff, estimatedCutoff);
            StreamingKMeans clusterer = new StreamingKMeans(searcher, numStreamingClusters, estimatedCutoff);
            clusterer.cluster(syntheticData.getFirst());
            totalDistanceCutoff += clusterer.getDistanceCutoff();
            totalNumClusters += clusterer.getNumClusters();
            System.out.printf("[%d] %f\n", i, clusterer.getDistanceCutoff());
        }
        System.out.printf("Final: distanceCutoff: %f estNumClusters: %f\n",
                totalDistanceCutoff / numTests, totalNumClusters / numTests);
    }

    /**
     * Clusters the synthetic data and checks three properties:
     * <ul>
     *   <li>total weight is preserved;
     *   <li>every true center has a near neighbor among the streaming centroids;
     *   <li>the centroid weight assigned to each corner equals {@code NUM_DATA_POINTS / NUM_CORNERS}.
     * </ul>
     */
    @Test
    public void testClustering() {
        searcher.clear();
        int numStreamingClusters = estimateNumStreamingClusters();
        System.out.printf("k log n = %d\n", numStreamingClusters);
        StreamingKMeans clusterer = new StreamingKMeans(searcher, numStreamingClusters, estimateCutoff());

        List<Centroid> dataPoints = syntheticData.getFirst();
        List<Centroid> trueCenters = syntheticData.getSecond();

        // nanoTime is monotonic, unlike currentTimeMillis, so it is the right clock for elapsed time.
        long startNanos = System.nanoTime();
        if (allAtOnce) {
            clusterer.cluster(dataPoints);
        } else {
            for (Centroid point : dataPoints) {
                clusterer.cluster(point);
            }
        }
        long endNanos = System.nanoTime();

        System.out.printf("%s %s\n", searcher.getClass().getName(), searcher.getDistanceMeasure().getClass().getName());
        System.out.printf("Total number of clusters %d\n", clusterer.getNumClusters());
        double expectedTotalWeight = ClusteringUtils.totalWeight(dataPoints);
        double actualTotalWeight = ClusteringUtils.totalWeight(clusterer);
        System.out.printf("Weights: %f %f\n", expectedTotalWeight, actualTotalWeight);
        Assert.assertEquals("Total weight not preserved", expectedTotalWeight, actualTotalWeight, 1.0e-9);

        // Query the clusterer's own searcher with each true center. The result weight is presumably
        // the distance to the nearest streaming centroid, so it must stay small.
        double maxWeight = 0;
        for (Centroid center : trueCenters) {
            maxWeight = Math.max(searcher.search(center, 1).get(0).getWeight(), maxWeight);
        }
        Assert.assertTrue("Maximum weight too large " + maxWeight, maxWeight < 0.05);

        double clusterSeconds = (endNanos - startNanos) / 1.0e9;
        System.out.printf("%s\n%.2f for clustering\n%.1f us per row\n\n",
                searcher.getClass().getName(), clusterSeconds, clusterSeconds / dataPoints.size() * 1.0e6);

        // Assign each streaming centroid's weight to the corner whose true center is nearest.
        double[] cornerWeights = new double[NUM_CORNERS];
        BruteSearch trueFinder = new BruteSearch(new EuclideanDistanceMeasure());
        for (Centroid center : trueCenters) {
            trueFinder.add(center);
        }
        for (Centroid centroid : clusterer) {
            Centroid closest = (Centroid) trueFinder.search(centroid, 1).get(0).getValue();
            cornerWeights[closest.getIndex()] += centroid.getWeight();
        }
        for (double weight : cornerWeights) {
            System.out.printf("%f ", weight);
        }
        System.out.println();

        int expectedPointsPerCorner = NUM_DATA_POINTS / NUM_CORNERS;
        for (double weight : cornerWeights) {
            Assert.assertEquals(expectedPointsPerCorner, weight, 0.0);
        }
    }
}
