package org.apache.mahout.math.hadoop.similarity;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.clustering.ClusteringTestUtils;
import org.apache.mahout.common.DummyOutputCollector;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.StringTuple;
import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.NamedVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.VectorWritable;
import org.easymock.EasyMock;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.class */
/**
 * Tests for {@link VectorDistanceSimilarityJob} and the two mappers it drives,
 * {@link VectorDistanceMapper} and {@link VectorDistanceInvertedMapper}.
 *
 * <p>The mapper tests use an EasyMock mock of the Hadoop {@code Mapper.Context} and check the
 * exact records written for a single input vector. The job tests run the whole job against
 * temporary directories and check the number (and, for the inverted form, the width) of the
 * emitted records.
 */
public class TestVectorDistanceSimilarityJob extends MahoutTestCase {

    /** Input points handed to the job in {@link #testRun()} and {@link #testRunInverted()}. */
    private static final double[][] REFERENCE = {
        {1.0, 1.0}, {2.0, 1.0}, {1.0, 2.0},
        {2.0, 2.0}, {3.0, 3.0}, {4.0, 4.0},
        {5.0, 4.0}, {4.0, 5.0}, {5.0, 5.0}
    };

    /** Seed points every input point is measured against. */
    private static final double[][] SEEDS = {{1.0, 1.0}, {10.0, 10.0}};

    /** File system obtained from a default {@link Configuration}; used to write the fixtures. */
    private FileSystem fs;

    @Override
    @Before
    public void setUp() throws Exception {
        super.setUp();
        fs = FileSystem.get(new Configuration());
    }

    /**
     * Maps the vector (2, 2) with key 123 against the seeds "foo" = (1, 1) and "foo2" = (2, 1)
     * and expects one (seed name, "123") tuple per seed, carrying the distances sqrt(2) and 1.
     */
    @Test
    public void testVectorDistanceMapper() throws Exception {
        // createMock can only be given the raw Context class, hence the unchecked assignment.
        @SuppressWarnings("unchecked")
        Mapper<WritableComparable<?>, VectorWritable, StringTuple, DoubleWritable>.Context context =
            EasyMock.createMock(Mapper.Context.class);

        // Record phase: the writes the mapper is expected to perform.
        StringTuple fooKey = new StringTuple();
        fooKey.add("foo");
        fooKey.add("123");
        context.write(fooKey, new DoubleWritable(Math.sqrt(2.0)));

        StringTuple foo2Key = new StringTuple();
        foo2Key.add("foo2");
        foo2Key.add("123");
        context.write(foo2Key, new DoubleWritable(1.0));

        EasyMock.replay(context);

        RandomAccessSparseVector input = new RandomAccessSparseVector(2);
        input.set(0, 2.0);
        input.set(1, 2.0);

        VectorDistanceMapper mapper = new VectorDistanceMapper();
        // The mapper is configured through its fields directly instead of via setup().
        setField(mapper, "measure", new EuclideanDistanceMeasure());
        setField(mapper, "seedVectors", buildSeedVectors());

        mapper.map(new IntWritable(123), new VectorWritable(input), context);

        EasyMock.verify(context);
    }

    /**
     * Maps the named vector "other" = (2, 2) against the seeds (1, 1) and (2, 1) and expects a
     * single record keyed by the vector's name whose value holds both distances, sqrt(2) and 1.
     */
    @Test
    public void testVectorDistanceInvertedMapper() throws Exception {
        // createMock can only be given the raw Context class, hence the unchecked assignment.
        @SuppressWarnings("unchecked")
        Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable>.Context context =
            EasyMock.createMock(Mapper.Context.class);

        // Record phase: one dense vector of distances, in seed order.
        context.write(new Text("other"),
            new VectorWritable(new DenseVector(new double[] {Math.sqrt(2.0), 1.0})));

        EasyMock.replay(context);

        NamedVector input = new NamedVector(new RandomAccessSparseVector(2), "other");
        input.set(0, 2.0);
        input.set(1, 2.0);

        VectorDistanceInvertedMapper mapper = new VectorDistanceInvertedMapper();
        // The mapper is configured through its fields directly instead of via setup().
        setField(mapper, "measure", new EuclideanDistanceMeasure());
        setField(mapper, "seedVectors", buildSeedVectors());

        mapper.map(new IntWritable(123), new VectorWritable(input), context);

        EasyMock.verify(context);
    }

    /**
     * Runs the job in its default output form and expects one collected key per
     * (seed, reference point) combination.
     */
    @Test
    public void testRun() throws Exception {
        Path input = getTestTempDirPath("input");
        Path output = getTestTempDirPath("output");
        Path seedsPath = getTestTempDirPath("seeds");

        List<VectorWritable> points = getPointsWritable(REFERENCE);
        List<VectorWritable> seeds = getPointsWritable(SEEDS);

        Configuration conf = new Configuration();
        ClusteringTestUtils.writePointsToFile(points, true, new Path(input, "file1"), fs, conf);
        ClusteringTestUtils.writePointsToFile(seeds, true, new Path(seedsPath, "part-seeds"), fs, conf);

        String[] args = {
            optKey("input"), input.toString(),
            optKey("seeds"), seedsPath.toString(),
            optKey("output"), output.toString(),
            optKey("distanceMeasure"), EuclideanDistanceMeasure.class.getName()
        };
        ToolRunner.run(new Configuration(), new VectorDistanceSimilarityJob(), args);

        int expectedOutputSize = SEEDS.length * REFERENCE.length;

        // Typed iteration: keys and values go to the collector as read, with no casts.
        DummyOutputCollector<StringTuple, DoubleWritable> collector =
            new DummyOutputCollector<StringTuple, DoubleWritable>();
        for (Pair<StringTuple, DoubleWritable> record
            : new SequenceFileIterable<StringTuple, DoubleWritable>(new Path(output, "part-m-00000"), conf)) {
            collector.collect(record.getFirst(), record.getSecond());
        }

        assertEquals(expectedOutputSize, collector.getData().size());
    }

    /**
     * Runs the job with {@code outType} set to {@code "v"} and expects one collected key per
     * reference point, each with a vector that has one entry per seed.
     */
    @Test
    public void testRunInverted() throws Exception {
        Path input = getTestTempDirPath("input");
        Path output = getTestTempDirPath("output");
        Path seedsPath = getTestTempDirPath("seeds");

        List<VectorWritable> points = getPointsWritable(REFERENCE);
        List<VectorWritable> seeds = getPointsWritable(SEEDS);

        Configuration conf = new Configuration();
        ClusteringTestUtils.writePointsToFile(points, true, new Path(input, "file1"), fs, conf);
        ClusteringTestUtils.writePointsToFile(seeds, true, new Path(seedsPath, "part-seeds"), fs, conf);

        String[] args = {
            optKey("input"), input.toString(),
            optKey("seeds"), seedsPath.toString(),
            optKey("output"), output.toString(),
            optKey("distanceMeasure"), EuclideanDistanceMeasure.class.getName(),
            optKey("outType"), "v"
        };
        ToolRunner.run(new Configuration(), new VectorDistanceSimilarityJob(), args);

        // Typed iteration: keys and values go to the collector as read, with no casts.
        DummyOutputCollector<Text, VectorWritable> collector =
            new DummyOutputCollector<Text, VectorWritable>();
        for (Pair<Text, VectorWritable> record
            : new SequenceFileIterable<Text, VectorWritable>(new Path(output, "part-m-00000"), conf)) {
            collector.collect(record.getFirst(), record.getSecond());
        }

        assertEquals(REFERENCE.length, collector.getData().size());
        for (Map.Entry<Text, List<VectorWritable>> entry : collector.getData().entrySet()) {
            // Only the first vector collected for each key is inspected.
            assertEquals(SEEDS.length, entry.getValue().iterator().next().get().size());
        }
    }

    /**
     * Wraps each row of {@code dArr} in a {@link VectorWritable} backed by a
     * {@link RandomAccessSparseVector} of the same length as the row.
     *
     * @param dArr the points, one row per point
     * @return a new list with one writable per row, in row order
     */
    public static List<VectorWritable> getPointsWritable(double[][] dArr) {
        List<VectorWritable> points = Lists.newArrayList();
        for (double[] row : dArr) {
            RandomAccessSparseVector vector = new RandomAccessSparseVector(row.length);
            vector.assign(row);
            points.add(new VectorWritable(vector));
        }
        return points;
    }

    /** Builds the two seed vectors shared by the mapper tests: "foo" = (1, 1), "foo2" = (2, 1). */
    private static List<NamedVector> buildSeedVectors() {
        RandomAccessSparseVector foo = new RandomAccessSparseVector(2);
        foo.set(0, 1.0);
        foo.set(1, 1.0);

        RandomAccessSparseVector foo2 = new RandomAccessSparseVector(2);
        foo2.set(0, 2.0);
        foo2.set(1, 1.0);

        List<NamedVector> seedVectors = Lists.newArrayList();
        seedVectors.add(new NamedVector(foo, "foo"));
        seedVectors.add(new NamedVector(foo2, "foo2"));
        return seedVectors;
    }
}
