package org.apache.mahout.clustering.kmeans;

import com.google.common.collect.Lists;
import com.google.common.io.Closeables;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.clustering.ClusteringTestUtils;
import org.apache.mahout.clustering.canopy.CanopyDriver;
import org.apache.mahout.clustering.classify.WeightedVectorWritable;
import org.apache.mahout.clustering.iterator.ClusterWritable;
import org.apache.mahout.common.DummyOutputCollector;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:org/apache/mahout/clustering/kmeans/TestKmeansClustering.class */
/**
 * Tests for {@link KMeansDriver}: runs k-means over the small two-dimensional {@link #REFERENCE}
 * data set with the sequential implementation, through {@link ToolRunner}, and on top of clusters
 * produced by {@link CanopyDriver}.
 */
public final class TestKmeansClustering extends MahoutTestCase {

    /** Nine two-dimensional reference points used as input by every test in this class. */
    public static final double[][] REFERENCE = {
        {1, 1}, {2, 1}, {1, 2}, {2, 2}, {3, 3}, {4, 4}, {5, 4}, {4, 5}, {5, 5}
    };

    /**
     * Row {@code k} describes the clustering expected when {@code k + 1} initial clusters are
     * supplied. The tests below only assert on the row length, i.e. the number of clusters.
     */
    private static final int[][] EXPECTED_NUM_POINTS = {
        {9},
        {4, 5},
        {4, 4, 1},
        {1, 2, 1, 5},
        {1, 1, 1, 2, 4},
        {1, 1, 1, 1, 1, 4},
        {1, 1, 1, 1, 1, 2, 2},
        {1, 1, 1, 1, 1, 1, 2, 1},
        {1, 1, 1, 1, 1, 1, 1, 1, 1}
    };

    /** Tolerance used when comparing canopy centers against their expected coordinates. */
    private static final double CENTER_EPSILON = 1.0E-6;

    /** File system used to write the input points; created in {@link #setUp()}. */
    private FileSystem fs;

    @Override
    @Before
    public void setUp() throws Exception {
        super.setUp();
        fs = FileSystem.get(new Configuration());
    }

    /**
     * Wraps each row of {@code dArr} in a {@link VectorWritable} backed by a
     * {@link RandomAccessSparseVector}.
     *
     * @param dArr the raw points, one row per point
     * @return a new mutable list with one writable per row, in row order
     */
    public static List<VectorWritable> getPointsWritable(double[][] dArr) {
        List<VectorWritable> points = Lists.newArrayList();
        for (double[] row : dArr) {
            Vector vector = new RandomAccessSparseVector(row.length);
            vector.assign(row);
            points.add(new VectorWritable(vector));
        }
        return points;
    }

    /**
     * Converts each row of {@code dArr} into a {@link SequentialAccessSparseVector}.
     *
     * @param dArr the raw points, one row per point
     * @return a new mutable list with one vector per row, in row order
     */
    public static List<Vector> getPoints(double[][] dArr) {
        List<Vector> points = Lists.newArrayList();
        for (double[] row : dArr) {
            Vector vector = new SequentialAccessSparseVector(row.length);
            vector.assign(row);
            points.add(vector);
        }
        return points;
    }

    /**
     * Runs the sequential k-means implementation for every k from 2 to 9 initial clusters and
     * checks the number of distinct cluster ids found in the clustered points.
     */
    @Test
    public void testKMeansSeqJob() throws Exception {
        List<VectorWritable> points = getPointsWritable(REFERENCE);
        Path pointsPath = getTestTempDirPath("points");
        Path clustersPath = getTestTempDirPath("clusters");
        Configuration conf = new Configuration();
        writeInputFiles(points, pointsPath, conf);

        for (int k = 1; k < points.size(); k++) {
            System.out.println("testKMeansSeqJob k= " + k);
            // Seed the run with the first k + 1 points as initial cluster centers.
            writeInitialClusters(points, k + 1, new Path(clustersPath, "part-00000"), conf);

            Path outputPath = getTestTempDirPath("output");
            String[] args = {
                optKey("input"), pointsPath.toString(),
                optKey("clusters"), clustersPath.toString(),
                optKey("output"), outputPath.toString(),
                optKey("distanceMeasure"), EuclideanDistanceMeasure.class.getName(),
                optKey("convergenceDelta"), "0.001",
                optKey("maxIter"), "2",
                optKey("clustering"),
                optKey("overwrite"),
                optKey("method"), "sequential"
            };
            new KMeansDriver().run(args);

            Path clusteredPointsPath = new Path(outputPath, "clusteredPoints");
            DummyOutputCollector<IntWritable, WeightedVectorWritable> collector =
                readClusteredPoints(new Path(clusteredPointsPath, "part-m-0"), conf);
            int[] expected = EXPECTED_NUM_POINTS[k];
            assertEquals("clusters[" + k + ']', expected.length, collector.getKeys().size());
        }
    }

    /**
     * Runs k-means through {@link ToolRunner} (no explicit execution method) for k = 1, 4 and 7
     * and checks the number of distinct cluster ids found in the clustered points.
     */
    @Test
    public void testKMeansMRJob() throws Exception {
        List<VectorWritable> points = getPointsWritable(REFERENCE);
        Path pointsPath = getTestTempDirPath("points");
        Path clustersPath = getTestTempDirPath("clusters");
        Configuration conf = new Configuration();
        writeInputFiles(points, pointsPath, conf);

        // Only every third k is exercised here.
        for (int k = 1; k < points.size(); k += 3) {
            System.out.println("testKMeansMRJob k= " + k);
            // Seed the run with the first k + 1 points as initial cluster centers.
            writeInitialClusters(points, k + 1, new Path(clustersPath, "part-00000"), conf);

            Path outputPath = getTestTempDirPath("output");
            String[] args = {
                optKey("input"), pointsPath.toString(),
                optKey("clusters"), clustersPath.toString(),
                optKey("output"), outputPath.toString(),
                optKey("distanceMeasure"), EuclideanDistanceMeasure.class.getName(),
                optKey("convergenceDelta"), "0.001",
                optKey("maxIter"), "2",
                optKey("clustering"),
                optKey("overwrite")
            };
            ToolRunner.run(new Configuration(), new KMeansDriver(), args);

            Path clusteredPointsPath = new Path(outputPath, "clusteredPoints");
            DummyOutputCollector<IntWritable, WeightedVectorWritable> collector =
                readClusteredPoints(new Path(clusteredPointsPath, "part-m-00000"), conf);
            int[] expected = EXPECTED_NUM_POINTS[k];
            assertEquals("clusters[" + k + ']', expected.length, collector.getKeys().size());
        }
    }

    /**
     * Runs canopy clustering over the reference points, verifies that exactly the two expected
     * canopy centers are produced, then feeds those canopies to k-means as initial clusters and
     * checks the two resulting clusters for membership and size.
     */
    @Test
    public void testKMeansWithCanopyClusterInput() throws Exception {
        List<VectorWritable> points = getPointsWritable(REFERENCE);
        Path pointsPath = getTestTempDirPath("points");
        Configuration conf = new Configuration();
        writeInputFiles(points, pointsPath, conf);

        Path outputPath = getTestTempDirPath("output");
        CanopyDriver.run(conf, pointsPath, outputPath, new ManhattanDistanceMeasure(), 3.1, 2.1, false, 0.0, false);

        // Collect every canopy written under clusters-0-final, keyed by cluster identifier.
        DummyOutputCollector<Text, ClusterWritable> canopies = new DummyOutputCollector<Text, ClusterWritable>();
        FileStatus[] canopyFiles = FileSystem.get(conf).globStatus(new Path(outputPath, "clusters-0-final/*-0*"));
        for (FileStatus canopyFile : canopyFiles) {
            for (Pair<Text, ClusterWritable> record
                : new SequenceFileIterable<Text, ClusterWritable>(canopyFile.getPath(), conf)) {
                canopies.collect(record.getFirst(), record.getSecond());
            }
        }

        // Each expected center must show up exactly once; anything else is a failure.
        boolean sawLowCenter = false;
        boolean sawHighCenter = false;
        int numCenters = 0;
        for (Text canopyId : canopies.getKeys()) {
            numCenters++;
            List<ClusterWritable> values = canopies.getValue(canopyId);
            assertEquals("non-singleton centroid!", 1, values.size());
            Vector center = values.get(0).getValue().getCenter();
            assertEquals("cetriod vector is wrong length", 2, center.size());
            if (!sawLowCenter && isCenterNear(center, 1.5)) {
                sawLowCenter = true;
            } else if (!sawHighCenter && isCenterNear(center, 4.333333333333334)) {
                sawHighCenter = true;
            } else {
                fail("got unexpected center: " + center + " [" + center.getClass().toString() + "]");
            }
        }
        assertEquals("got unexpected number of centers", 2, numCenters);

        // Use the canopies as the initial clusters for k-means.
        Path kmeansOutput = new Path(outputPath, "kmeans");
        KMeansDriver.run(pointsPath, new Path(outputPath, "clusters-0-final"), kmeansOutput,
            new EuclideanDistanceMeasure(), 0.001, 10, true, 0.0, false);

        Path clusteredPointsPath = new Path(kmeansOutput, "clusteredPoints");
        DummyOutputCollector<IntWritable, WeightedVectorWritable> clustered =
            readClusteredPoints(new Path(clusteredPointsPath, "part-m-00000"), conf);

        for (IntWritable clusterId : clustered.getKeys()) {
            List<WeightedVectorWritable> members = clustered.getValue(clusterId);
            assertTrue("empty cluster!", !members.isEmpty());
            // The first member's x coordinate decides which of the two clusters this is.
            if (members.get(0).getVector().get(0) <= 2.0) {
                for (WeightedVectorWritable member : members) {
                    Vector vector = member.getVector();
                    assertTrue("bad cluster!", vector.get(vector.maxValueIndex()) <= 2.0);
                }
                assertEquals("Wrong size cluster", 4, members.size());
            } else {
                for (WeightedVectorWritable member : members) {
                    Vector vector = member.getVector();
                    assertTrue("bad cluster!", vector.get(vector.minValueIndex()) > 2.0);
                }
                assertEquals("Wrong size cluster", 5, members.size());
            }
        }
    }

    /** Writes {@code points} to both {@code file1} and {@code file2} under {@code dir}. */
    private void writeInputFiles(List<VectorWritable> points, Path dir, Configuration conf) throws Exception {
        ClusteringTestUtils.writePointsToFile(points, true, new Path(dir, "file1"), fs, conf);
        ClusteringTestUtils.writePointsToFile(points, true, new Path(dir, "file2"), fs, conf);
    }

    /**
     * Writes the first {@code numClusters} points as initial {@link Kluster}s to
     * {@code clustersFile}. The writer stays open for all appends and is closed exactly once.
     */
    private static void writeInitialClusters(List<VectorWritable> points, int numClusters, Path clustersFile,
                                             Configuration conf) throws Exception {
        EuclideanDistanceMeasure measure = new EuclideanDistanceMeasure();
        FileSystem clustersFs = FileSystem.get(clustersFile.toUri(), conf);
        SequenceFile.Writer writer =
            new SequenceFile.Writer(clustersFs, conf, clustersFile, Text.class, Kluster.class);
        try {
            for (int id = 0; id < numClusters; id++) {
                Kluster cluster = new Kluster(points.get(id).get(), id, measure);
                // Add the center as an observation of the cluster before writing it out.
                cluster.observe(cluster.getCenter(), 1.0);
                writer.append(new Text(cluster.getIdentifier()), cluster);
            }
        } finally {
            Closeables.closeQuietly(writer);
        }
    }

    /** Reads every (key, weighted vector) record of {@code file} into a collector grouped by key. */
    private static DummyOutputCollector<IntWritable, WeightedVectorWritable> readClusteredPoints(
        Path file, Configuration conf) throws Exception {
        DummyOutputCollector<IntWritable, WeightedVectorWritable> collector =
            new DummyOutputCollector<IntWritable, WeightedVectorWritable>();
        for (Pair<IntWritable, WeightedVectorWritable> record
            : new SequenceFileIterable<IntWritable, WeightedVectorWritable>(file, conf)) {
            collector.collect(record.getFirst(), record.getSecond());
        }
        return collector;
    }

    /** Returns true when both coordinates of {@code center} are within tolerance of {@code expected}. */
    private static boolean isCenterNear(Vector center, double expected) {
        return Math.abs(center.get(0) - expected) < CENTER_EPSILON
            && Math.abs(center.get(1) - expected) < CENTER_EPSILON;
    }
}
