package org.apache.mahout.clustering.meanshift;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.clustering.ClusteringTestUtils;
import org.apache.mahout.clustering.iterator.ClusterWritable;
import org.apache.mahout.common.DummyRecordWriter;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterator;
import org.apache.mahout.common.kernel.IKernelProfile;
import org.apache.mahout.common.kernel.TriangularKernelProfile;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:org/apache/mahout/clustering/meanshift/TestMeanShift.class */
/**
 * Tests for mean shift canopy clustering: the in-memory reference clusterer, the
 * {@link MeanShiftCanopyMapper} / {@link MeanShiftCanopyReducer} pair driven through
 * {@link DummyRecordWriter}, and full {@link MeanShiftCanopyDriver} runs in both
 * map/reduce and sequential mode.
 */
public final class TestMeanShift extends MahoutTestCase {
    /** Number of rows and columns of the square grid of test points built in {@link #setUp()}. */
    private static final int GRID_SIZE = 10;

    /** Test points, rebuilt by {@link #setUp()} before every test. */
    private Vector[] raw = null;
    private final DistanceMeasure euclideanDistanceMeasure = new EuclideanDistanceMeasure();
    private final IKernelProfile kernelProfile = new TriangularKernelProfile();

    /**
     * Prints the format string of every canopy to standard output, one per line.
     *
     * @param canopies the canopies to print
     */
    private static void printCanopies(Iterable<MeanShiftCanopy> canopies) {
        for (MeanShiftCanopy canopy : canopies) {
            // The cast selects the String[] overload for the null "bindings" argument.
            System.out.println(canopy.asFormatString((String[]) null));
        }
    }

    /**
     * Prints a character grid to standard output in which every bound point of a canopy is
     * drawn, at the position given by its first two coordinates, as the letter
     * {@code 'A' + canopy id}. Cells not covered by any canopy stay blank.
     *
     * @param canopies the canopies to draw
     */
    private void printImage(Iterable<MeanShiftCanopy> canopies) {
        char[][] image = new char[GRID_SIZE][GRID_SIZE];
        for (char[] row : image) {
            for (int col = 0; col < row.length; col++) {
                row[col] = ' ';
            }
        }
        for (MeanShiftCanopy canopy : canopies) {
            char label = (char) ('A' + canopy.getId());
            // Bound points are indices into raw; iterate as Object so this compiles whether or
            // not toList() carries an element type.
            for (Object boundPoint : canopy.getBoundPoints().toList()) {
                Vector point = raw[((Integer) boundPoint).intValue()];
                image[(int) point.getQuick(0)][(int) point.getQuick(1)] = label;
            }
        }
        for (char[] row : image) {
            System.out.println(row);
        }
    }

    /**
     * Wraps every raw test point in its own canopy, using the point's index in {@code raw}
     * as the canopy id.
     *
     * @return one canopy per raw point, in the order of {@code raw}
     */
    private List<MeanShiftCanopy> getInitialCanopies() {
        List<MeanShiftCanopy> canopies = Lists.newArrayList();
        int nextId = 0;
        for (Vector point : raw) {
            canopies.add(new MeanShiftCanopy(point, nextId++, euclideanDistanceMeasure));
        }
        return canopies;
    }

    /**
     * Creates one canopy per raw point (ids are the point indices) and feeds each of them to
     * {@code clusterer.mergeCanopy}, collecting the canopies the clusterer keeps.
     *
     * @param clusterer the clusterer that performs the merging
     * @return the canopy list populated by the clusterer
     */
    private List<MeanShiftCanopy> mergeRawCanopies(MeanShiftCanopyClusterer clusterer) {
        List<MeanShiftCanopy> merged = Lists.newArrayList();
        int nextId = 0;
        for (Vector point : raw) {
            clusterer.mergeCanopy(new MeanShiftCanopy(point, nextId++, euclideanDistanceMeasure), merged);
        }
        return merged;
    }

    /**
     * Builds the key under which the mapper and reducer tests look a canopy up in the job
     * output: {@code "MSV-"} for converged canopies, {@code "MSC-"} otherwise, followed by
     * the canopy id.
     */
    private static String canopyKey(MeanShiftCanopy canopy) {
        return (canopy.isConverged() ? "MSV-" : "MSC-") + canopy.getId();
    }

    /**
     * Creates the configuration shared by the mapper and reducer tests: Euclidean distance,
     * triangular kernel profile, t1=4, t2=1 and convergence delta 0.5.
     */
    private static Configuration newCanopyConfiguration() {
        Configuration conf = new Configuration();
        conf.set("org.apache.mahout.clustering.canopy.measure", "org.apache.mahout.common.distance.EuclideanDistanceMeasure");
        conf.set("org.apache.mahout.clustering.canopy.kernelprofile", "org.apache.mahout.common.kernel.TriangularKernelProfile");
        conf.set("org.apache.mahout.clustering.canopy.t1", "4");
        conf.set("org.apache.mahout.clustering.canopy.t2", "1");
        conf.set("org.apache.mahout.clustering.canopy.convergence", "0.5");
        return conf;
    }

    /**
     * Runs setup, one {@code map} call per canopy, and cleanup on a fresh
     * {@link MeanShiftCanopyMapper}, and asserts that the output holds exactly one key.
     *
     * @param conf     the job configuration handed to the mapper
     * @param canopies the canopies to map, each wrapped in its own {@link ClusterWritable}
     * @return the writer holding everything the mapper emitted
     */
    // The context is kept as a raw type so that no additional import is needed to spell out
    // the mapper's input key type.
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static DummyRecordWriter<Text, ClusterWritable> runMapper(Configuration conf,
                                                                      Iterable<MeanShiftCanopy> canopies)
        throws Exception {
        MeanShiftCanopyMapper mapper = new MeanShiftCanopyMapper();
        DummyRecordWriter<Text, ClusterWritable> mapWriter = new DummyRecordWriter<Text, ClusterWritable>();
        Mapper.Context mapContext = DummyRecordWriter.build(mapper, conf, mapWriter);
        mapper.setup(mapContext);
        for (MeanShiftCanopy canopy : canopies) {
            ClusterWritable writable = new ClusterWritable();
            writable.setValue(canopy);
            mapper.map(new Text(), writable, mapContext);
        }
        mapper.cleanup(mapContext);
        assertEquals("Number of map results", 1, mapWriter.getData().size());
        return mapWriter;
    }

    /**
     * Runs setup, a single {@code reduce} call over the values stored under key {@code "0"}
     * in {@code mapWriter}, and cleanup on a fresh {@link MeanShiftCanopyReducer}.
     *
     * @param conf      the job configuration handed to the reducer
     * @param mapWriter the mapper output to reduce
     * @return the writer holding everything the reducer emitted
     */
    // Raw context for the same reason as in runMapper.
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static DummyRecordWriter<Text, ClusterWritable> runReducer(Configuration conf,
                                                                       DummyRecordWriter<Text, ClusterWritable> mapWriter)
        throws Exception {
        MeanShiftCanopyReducer reducer = new MeanShiftCanopyReducer();
        DummyRecordWriter<Text, ClusterWritable> reduceWriter = new DummyRecordWriter<Text, ClusterWritable>();
        Reducer.Context reduceContext =
            DummyRecordWriter.build(reducer, conf, reduceWriter, Text.class, ClusterWritable.class);
        reducer.setup(reduceContext);
        reducer.reduce(new Text("0"), mapWriter.getValue(new Text("0")), reduceContext);
        reducer.cleanup(reduceContext);
        return reduceWriter;
    }

    /**
     * Writes the given vectors, in order, to both {@code testdata/file1} and
     * {@code testdata/file2} under the test temp directory.
     *
     * @param vectors the points to write
     * @param fs      the file system to write to
     * @param conf    the configuration used for writing
     */
    private void writePoints(Vector[] vectors, FileSystem fs, Configuration conf) throws Exception {
        List<VectorWritable> points = Lists.newArrayList();
        for (Vector vector : vectors) {
            points.add(new VectorWritable(vector));
        }
        ClusteringTestUtils.writePointsToFile(points, getTestTempFilePath("testdata/file1"), fs, conf);
        ClusteringTestUtils.writePointsToFile(points, getTestTempFilePath("testdata/file2"), fs, conf);
    }

    /**
     * Assembles the {@link MeanShiftCanopyDriver} command line used by the job tests:
     * input {@code testdata}, the given output, Euclidean distance, triangular kernel,
     * t1=4, t2=1, maxIter=7, convergenceDelta=0.2 and overwrite.
     *
     * @param output       the job output directory
     * @param runClustering whether to pass the {@code clustering} flag
     * @param sequential   whether to append {@code method sequential}
     * @return the argument array, in the order the tests have always passed it
     */
    private String[] driverArgs(Path output, boolean runClustering, boolean sequential) throws Exception {
        List<String> args = Lists.newArrayList(
            optKey("input"), getTestTempDirPath("testdata").toString(),
            optKey("output"), output.toString(),
            optKey("distanceMeasure"), EuclideanDistanceMeasure.class.getName(),
            optKey("kernelProfile"), TriangularKernelProfile.class.getName(),
            optKey("t1"), "4",
            optKey("t2"), "1");
        if (runClustering) {
            args.add(optKey("clustering"));
        }
        args.add(optKey("maxIter"));
        args.add("7");
        args.add(optKey("convergenceDelta"));
        args.add("0.2");
        args.add(optKey("overwrite"));
        if (sequential) {
            args.add(optKey("method"));
            args.add("sequential");
        }
        return args.toArray(new String[args.size()]);
    }

    /**
     * Builds the test points: one 3-element vector per cell of a {@code GRID_SIZE} square
     * grid, stored row-major. Elements 0 and 1 hold the row and column; element 2 is set to
     * 9 on the main diagonal, to 4.5 on the remaining anti-diagonal cells, and is otherwise
     * left at the vector's initial value.
     */
    @Override
    @Before
    public void setUp() throws Exception {
        super.setUp();
        raw = new Vector[GRID_SIZE * GRID_SIZE];
        for (int row = 0; row < GRID_SIZE; row++) {
            for (int col = 0; col < GRID_SIZE; col++) {
                Vector point = new DenseVector(3);
                point.setQuick(0, row);
                point.setQuick(1, col);
                if (row == col) {
                    point.setQuick(2, 9.0);
                } else if (row + col == GRID_SIZE - 1) {
                    point.setQuick(2, 4.5);
                }
                raw[row * GRID_SIZE + col] = point;
            }
        }
    }

    /**
     * Runs the clusterer by hand: merges all raw points into canopies, then repeatedly
     * shifts every canopy to its mean and re-merges until every shift reports convergence,
     * printing the canopies, their image and the iteration number after each pass.
     * This test only prints; it makes no assertions.
     */
    @Test
    public void testReferenceImplementation() {
        MeanShiftCanopyClusterer clusterer =
            new MeanShiftCanopyClusterer(new EuclideanDistanceMeasure(), new TriangularKernelProfile(), 4.0, 1.0, 0.5, true);
        List<MeanShiftCanopy> canopies = mergeRawCanopies(clusterer);
        boolean converged = false;
        int iteration = 1;
        while (!converged) {
            converged = true;
            List<MeanShiftCanopy> migrated = Lists.newArrayList();
            for (MeanShiftCanopy canopy : canopies) {
                // shiftToMean is the left operand so that it runs for every canopy.
                converged = clusterer.shiftToMean(canopy) && converged;
                clusterer.mergeCanopy(canopy, migrated);
            }
            canopies = migrated;
            printCanopies(canopies);
            printImage(canopies);
            System.out.println(iteration++);
        }
    }

    /**
     * Clusters the raw points through {@code MeanShiftCanopyClusterer.clusterPoints} and
     * prints the result. This test only prints; it makes no assertions.
     */
    @Test
    public void testClustererReferenceImplementation() {
        List<MeanShiftCanopy> canopies = MeanShiftCanopyClusterer.clusterPoints(
            Lists.newArrayList(raw), euclideanDistanceMeasure, kernelProfile, 0.5, 4.0, 1.0, 10);
        printCanopies(canopies);
        printImage(canopies);
    }

    /**
     * Checks that mapping the initial canopies yields the same canopies (id, center, bound
     * points, observation count) as merging the raw points with the reference clusterer and
     * shifting each merged canopy to its mean once.
     */
    @Test
    public void testCanopyMapperEuclidean() throws Exception {
        MeanShiftCanopyClusterer clusterer =
            new MeanShiftCanopyClusterer(euclideanDistanceMeasure, kernelProfile, 4.0, 1.0, 0.5, true);
        List<MeanShiftCanopy> initialCanopies = getInitialCanopies();
        List<MeanShiftCanopy> refCanopies = mergeRawCanopies(clusterer);

        DummyRecordWriter<Text, ClusterWritable> mapWriter = runMapper(newCanopyConfiguration(), initialCanopies);
        List<ClusterWritable> mapped = mapWriter.getValue(new Text("0"));
        assertEquals("Number of canopies", refCanopies.size(), mapped.size());

        Map<String, MeanShiftCanopy> expected = Maps.newHashMap();
        for (MeanShiftCanopy canopy : refCanopies) {
            clusterer.shiftToMean(canopy);
            expected.put(canopy.getIdentifier(), canopy);
        }
        Map<String, MeanShiftCanopy> actual = Maps.newHashMap();
        for (ClusterWritable writable : mapped) {
            MeanShiftCanopy canopy = (MeanShiftCanopy) writable.getValue();
            actual.put(canopy.getIdentifier(), canopy);
        }

        for (MeanShiftCanopy ref : expected.values()) {
            String key = canopyKey(ref);
            MeanShiftCanopy canopy = actual.get(key);
            assertNotNull("mapper emitted no canopy for " + key, canopy);
            assertEquals("ids", ref.getId(), canopy.getId());
            assertEquals("centers(" + ref.getIdentifier() + ')',
                ref.getCenter().asFormatString(), canopy.getCenter().asFormatString());
            assertEquals("bound points",
                ref.getBoundPoints().toList().size(), canopy.getBoundPoints().toList().size());
            assertEquals("num bound points", ref.getNumObservations(), canopy.getNumObservations());
        }
    }

    /**
     * Checks that mapping and then reducing the initial canopies yields the same canopies as
     * two rounds of merge-then-shift-to-mean with the reference clusterer.
     */
    @Test
    public void testCanopyReducerEuclidean() throws Exception {
        MeanShiftCanopyClusterer clusterer =
            new MeanShiftCanopyClusterer(euclideanDistanceMeasure, kernelProfile, 4.0, 1.0, 0.5, true);
        List<MeanShiftCanopy> initialCanopies = getInitialCanopies();

        // Reference for the map phase: merge, then shift every canopy once.
        List<MeanShiftCanopy> mapperReference = mergeRawCanopies(clusterer);
        for (MeanShiftCanopy canopy : mapperReference) {
            clusterer.shiftToMean(canopy);
        }
        // Reference for the reduce phase: merge the shifted canopies again, then shift once more.
        List<MeanShiftCanopy> reducerReference = Lists.newArrayList();
        for (MeanShiftCanopy canopy : mapperReference) {
            clusterer.mergeCanopy(canopy, reducerReference);
        }
        for (MeanShiftCanopy canopy : reducerReference) {
            clusterer.shiftToMean(canopy);
        }

        Configuration conf = newCanopyConfiguration();
        conf.set("org.apache.mahout.clustering.control.path", "output/control");
        DummyRecordWriter<Text, ClusterWritable> mapWriter = runMapper(conf, initialCanopies);
        DummyRecordWriter<Text, ClusterWritable> reduceWriter = runReducer(conf, mapWriter);
        assertEquals("Number of canopies", reducerReference.size(), reduceWriter.getKeys().size());

        Map<String, MeanShiftCanopy> expected = Maps.newHashMap();
        for (MeanShiftCanopy canopy : reducerReference) {
            expected.put(canopy.getIdentifier(), canopy);
        }
        for (Map.Entry<String, MeanShiftCanopy> entry : expected.entrySet()) {
            MeanShiftCanopy ref = entry.getValue();
            String key = canopyKey(ref);
            List<ClusterWritable> reduced = reduceWriter.getValue(new Text(key));
            assertNotNull("reducer emitted nothing for " + key, reduced);
            assertEquals("values", 1, reduced.size());
            MeanShiftCanopy canopy = (MeanShiftCanopy) reduced.get(0).getValue();
            assertEquals("ids", ref.getId(), canopy.getId());
            assertEquals("numPoints", ref.getNumObservations(), canopy.getNumObservations());
            assertEquals("centers(" + entry.getKey() + ')',
                ref.getCenter().asFormatString(), canopy.getCenter().asFormatString());
            assertEquals("bound points",
                ref.getBoundPoints().toList().size(), canopy.getBoundPoints().toList().size());
            assertEquals("num bound points", ref.getNumObservations(), canopy.getNumObservations());
        }
    }

    /**
     * Runs the driver as a map/reduce job with clustering enabled on a resampled copy of the
     * test points, and checks the number of final clusters as well as the contents of the
     * initial {@code clusters-0} part file.
     */
    @Test
    public void testCanopyEuclideanMRJob() throws Exception {
        Path input = getTestTempDirPath("testdata");
        Configuration conf = new Configuration();
        FileSystem fs = FileSystem.get(input.toUri(), conf);

        Random random = new Random(123L);
        // Work on a copy so the shared fixture array is not modified.
        Vector[] resampled = raw.clone();
        // NOTE: this overwrites each slot with a randomly chosen later-or-equal slot instead
        // of swapping, so points can appear more than once. It is kept exactly as it was;
        // the expected counts below were presumably established with this input.
        for (int i = 0; i < resampled.length; i++) {
            resampled[i] = resampled[i + random.nextInt(raw.length - i)];
        }
        writePoints(resampled, fs, conf);

        Path output = getTestTempDirPath("output");
        ToolRunner.run(conf, new MeanShiftCanopyDriver(), driverArgs(output, true, false));

        FileStatus[] finalParts = FileSystem.get(conf).globStatus(new Path(output, "clusters-?-final/part-r-*"));
        assertEquals("Wrong number of matching final parts", 1, finalParts.length);
        assertEquals("count", 5, HadoopUtil.countRecords(finalParts[0].getPath(), conf));

        Iterator<ClusterWritable> initialClusters =
            new SequenceFileValueIterator<ClusterWritable>(new Path(output, "clusters-0/part-m-00000"), true, conf);
        while (initialClusters.hasNext()) {
            MeanShiftCanopy canopy = (MeanShiftCanopy) initialClusters.next().getValue();
            assertTrue(canopy.getCenter() instanceof DenseVector);
            assertFalse(canopy.getBoundPoints().isEmpty());
        }
    }

    /**
     * Runs the driver sequentially with clustering enabled and checks that
     * {@code clusters-7-final} holds three clusters.
     */
    @Test
    public void testCanopyEuclideanSeqJob() throws Exception {
        Path input = getTestTempDirPath("testdata");
        Configuration conf = new Configuration();
        FileSystem fs = FileSystem.get(input.toUri(), conf);
        writePoints(raw, fs, conf);

        Path output = getTestTempDirPath("output");
        System.out.println("Output Path: " + output);
        ToolRunner.run(new Configuration(), new MeanShiftCanopyDriver(), driverArgs(output, true, true));

        assertEquals("count", 3,
            HadoopUtil.countRecords(new Path(output, "clusters-7-final/part-r-00000"), conf));
    }

    /**
     * Runs the driver as a map/reduce job without the clustering flag and checks that
     * {@code clusters-3-final} holds three clusters, each with a dense center and exactly
     * one bound point.
     */
    @Test
    public void testCanopyEuclideanMRJobNoClustering() throws Exception {
        Path input = getTestTempDirPath("testdata");
        Configuration conf = new Configuration();
        FileSystem fs = FileSystem.get(input.toUri(), conf);
        writePoints(raw, fs, conf);

        Path output = getTestTempDirPath("output");
        ToolRunner.run(conf, new MeanShiftCanopyDriver(), driverArgs(output, false, false));

        Path finalPart = new Path(output, "clusters-3-final/part-r-00000");
        assertEquals("count", 3, HadoopUtil.countRecords(finalPart, conf));
        Iterator<ClusterWritable> clusters = new SequenceFileValueIterator<ClusterWritable>(finalPart, true, conf);
        while (clusters.hasNext()) {
            MeanShiftCanopy canopy = (MeanShiftCanopy) clusters.next().getValue();
            assertTrue(canopy.getCenter() instanceof DenseVector);
            assertEquals(1, canopy.getBoundPoints().size());
        }
    }

    /**
     * Runs the driver sequentially without the clustering flag and checks that
     * {@code clusters-7-final} holds three clusters, each with exactly one bound point.
     */
    @Test
    public void testCanopyEuclideanSeqJobNoClustering() throws Exception {
        Path input = getTestTempDirPath("testdata");
        Configuration conf = new Configuration();
        FileSystem fs = FileSystem.get(input.toUri(), conf);
        writePoints(raw, fs, conf);

        Path output = getTestTempDirPath("output");
        System.out.println("Output Path: " + output);
        ToolRunner.run(new Configuration(), new MeanShiftCanopyDriver(), driverArgs(output, false, true));

        Path finalPart = new Path(output, "clusters-7-final/part-r-00000");
        assertEquals("count", 3, HadoopUtil.countRecords(finalPart, conf));
        Iterator<ClusterWritable> clusters = new SequenceFileValueIterator<ClusterWritable>(finalPart, true, conf);
        while (clusters.hasNext()) {
            MeanShiftCanopy canopy = (MeanShiftCanopy) clusters.next().getValue();
            assertEquals(1, canopy.getBoundPoints().size());
        }
    }
}
