/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.clustering.streaming.cluster;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.List;
import org.apache.mahout.clustering.ClusteringUtils;
import org.apache.mahout.clustering.streaming.cluster.BallKMeans;
import org.apache.mahout.clustering.streaming.cluster.DataUtils;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
import org.apache.mahout.math.Centroid;
import org.apache.mahout.math.ConstantVector;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.SingularValueDecomposition;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.WeightedVector;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.function.VectorFunction;
import org.apache.mahout.math.neighborhood.BruteSearch;
import org.apache.mahout.math.neighborhood.UpdatableSearcher;
import org.apache.mahout.math.random.MultiNormal;
import org.apache.mahout.math.random.WeightedThing;
import org.apache.mahout.math.stats.OnlineSummarizer;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;

public class BallKMeansTest {
    private static final int NUM_DATA_POINTS = 10000;
    private static final int NUM_DIMENSIONS = 4;
    private static final int NUM_ITERATIONS = 20;
    private static final double DISTRIBUTION_RADIUS = 0.01;
    private static Pair<List<Centroid>, List<Centroid>> syntheticData;
    private static final int K1 = 100;

    @BeforeClass
    public static void setUp() {
        RandomUtils.useTestSeed();
        syntheticData = DataUtils.sampleMultiNormalHypercube(4, 10000, 0.01);
    }

    @Test
    public void testClusteringMultipleRuns() {
        for (int i = 1; i <= 10; ++i) {
            BallKMeans clusterer = new BallKMeans((UpdatableSearcher)new BruteSearch((DistanceMeasure)new SquaredEuclideanDistanceMeasure()), 16, 20, true, i);
            clusterer.cluster((List)syntheticData.getFirst());
            double costKMeansPlusPlus = ClusteringUtils.totalClusterCost((Iterable)((Iterable)syntheticData.getFirst()), (Iterable)clusterer);
            clusterer = new BallKMeans((UpdatableSearcher)new BruteSearch((DistanceMeasure)new SquaredEuclideanDistanceMeasure()), 16, 20, false, i);
            clusterer.cluster((List)syntheticData.getFirst());
            double costKMeansRandom = ClusteringUtils.totalClusterCost((Iterable)((Iterable)syntheticData.getFirst()), (Iterable)clusterer);
            System.out.printf("%d runs; kmeans++: %f; random: %f\n", i, costKMeansPlusPlus, costKMeansRandom);
            Assert.assertTrue((String)"kmeans++ cost should be less than random cost", (costKMeansPlusPlus < costKMeansRandom ? 1 : 0) != 0);
        }
    }

    @Test
    public void testClustering() {
        BruteSearch searcher = new BruteSearch((DistanceMeasure)new SquaredEuclideanDistanceMeasure());
        BallKMeans clusterer = new BallKMeans((UpdatableSearcher)searcher, 16, 20);
        long startTime = System.currentTimeMillis();
        Pair<List<Centroid>, List<Centroid>> data = syntheticData;
        clusterer.cluster((List)data.getFirst());
        long endTime = System.currentTimeMillis();
        long hash = 0L;
        for (Centroid centroid : (List)data.getFirst()) {
            for (Vector.Element element : centroid.all()) {
                hash = 31L * hash + (long)(17 * element.index()) + (long)Double.toHexString(element.get()).hashCode();
            }
        }
        System.out.printf("Hash = %08x\n", hash);
        Assert.assertEquals((String)"Total weight not preserved", (double)ClusteringUtils.totalWeight((Iterable)((Iterable)syntheticData.getFirst())), (double)ClusteringUtils.totalWeight((Iterable)clusterer), (double)1.0E-9);
        OnlineSummarizer summarizer = new OnlineSummarizer();
        for (Vector mean : (List)syntheticData.getSecond()) {
            WeightedThing v = (WeightedThing)searcher.search(mean, 1).get(0);
            summarizer.add(v.getWeight());
        }
        Assert.assertTrue((String)String.format("Median weight [%f] too large [>%f]", summarizer.getMedian(), 0.01), (summarizer.getMedian() < 0.01 ? 1 : 0) != 0);
        double clusterTime = (double)(endTime - startTime) / 1000.0;
        System.out.printf("%s\n%.2f for clustering\n%.1f us per row\n\n", searcher.getClass().getName(), clusterTime, clusterTime / (double)((List)syntheticData.getFirst()).size() * 1000000.0);
        double[] cornerWeights = new double[16];
        BruteSearch trueFinder = new BruteSearch((DistanceMeasure)new EuclideanDistanceMeasure());
        for (Vector trueCluster : (List)syntheticData.getSecond()) {
            trueFinder.add(trueCluster);
        }
        for (Centroid centroid : clusterer) {
            WeightedThing closest = (WeightedThing)trueFinder.search((Vector)centroid, 1).get(0);
            int n = ((Centroid)closest.getValue()).getIndex();
            cornerWeights[n] = cornerWeights[n] + centroid.getWeight();
        }
        int expectedNumPoints = 625;
        for (double v : cornerWeights) {
            System.out.printf("%f ", v);
        }
        System.out.println();
        for (double v : cornerWeights) {
            Assert.assertEquals((double)expectedNumPoints, (double)v, (double)0.0);
        }
    }

    @Test
    public void testInitialization() {
        List<? extends WeightedVector> data = BallKMeansTest.cubishTestData(0.01);
        BallKMeans r = new BallKMeans((UpdatableSearcher)new BruteSearch((DistanceMeasure)new SquaredEuclideanDistanceMeasure()), 6, 20);
        r.cluster(data);
        DenseMatrix x = new DenseMatrix(6, 5);
        int row = 0;
        for (Centroid c : r) {
            x.viewRow(row).assign(c.viewPart(0, 5));
            ++row;
        }
        Vector columnNorms = x.aggregateColumns(new VectorFunction(){

            public double apply(Vector f) {
                return Math.abs(f.minValue()) + Math.abs(f.maxValue() - 6.0) + Math.abs(f.norm(1.0) - 6.0);
            }
        });
        Assert.assertEquals((double)0.0, (double)(columnNorms.norm(1.0) / (double)columnNorms.size()), (double)0.1);
        SingularValueDecomposition svd = new SingularValueDecomposition((Matrix)x);
        Vector s = svd.getS().viewDiagonal().assign(Functions.div((double)6.0));
        Assert.assertEquals((double)5.0, (double)s.getLengthSquared(), (double)0.05);
        Assert.assertEquals((double)5.0, (double)s.norm(1.0), (double)0.05);
    }

    private static List<? extends WeightedVector> cubishTestData(double radius) {
        int i;
        ArrayList data = Lists.newArrayListWithCapacity((int)5100);
        int row = 0;
        MultiNormal g = new MultiNormal(radius, (Vector)new ConstantVector(0.0, 10));
        for (i = 0; i < 100; ++i) {
            data.add(new WeightedVector(g.sample(), 1.0, row++));
        }
        for (i = 0; i < 5; ++i) {
            DenseVector m = new DenseVector(10);
            m.set(i, 6.0);
            MultiNormal gx = new MultiNormal(radius, (Vector)m);
            for (int j = 0; j < 1000; ++j) {
                data.add(new WeightedVector(gx.sample(), 1.0, row++));
            }
        }
        return data;
    }
}

