package org.apache.mahout.math.neighborhood;

import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.distance.CosineDistanceMeasure;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.random.WeightedThing;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

/**
 * Compares approximate nearest-neighbor {@link Searcher}s against an exact {@link BruteSearch}
 * baseline, both for result quality and for average query time.
 *
 * <p>{@link #generateData()} builds one shared data set and query set and runs brute-force search
 * over them once. Each parameterized instance then tests a single approximate searcher. The tests
 * require the searcher to return the same closest vector as brute force for every query, and to
 * be faster on average than brute force for the same kind of query.
 */
@RunWith(Parameterized.class)
public class SearchQualityTest {
    private static final int NUM_DATA_POINTS = 16384;
    private static final int NUM_QUERIES = 1024;
    private static final int NUM_DIMENSIONS = 40;
    private static final int NUM_RESULTS = 2;

    private static final long NANOS_PER_MILLI = 1000000L;

    /** Approximate searcher under test for this parameterized run. */
    private final Searcher searcher;
    /** Vectors indexed by the searcher. */
    private final Matrix dataPoints;
    /** Query vectors; one result (list) is expected per row. */
    private final Matrix queries;
    /** Brute-force top-{@link #NUM_RESULTS} results per query, with total run time in milliseconds. */
    private final Pair<List<List<WeightedThing<Vector>>>, Long> reference;
    /** Brute-force closest result per query, with total run time in milliseconds. */
    private final Pair<List<WeightedThing<Vector>>, Long> referenceSearchFirst;

    /** Guava function that drops the weight of a {@link WeightedThing} and keeps its vector. */
    static class StripWeight implements Function<WeightedThing<Vector>, Vector> {
        @Override
        public Vector apply(WeightedThing<Vector> weightedThing) {
            Preconditions.checkArgument(weightedThing != null, "input is null");
            return weightedThing.getValue();
        }
    }

    /**
     * Builds the parameter sets: the same data, queries and brute-force reference results, paired
     * with each approximate searcher configuration under test.
     *
     * @return one {@code Object[]} per searcher, matching the constructor's parameter order
     */
    @Parameterized.Parameters
    public static List<Object[]> generateData() {
        RandomUtils.useTestSeed();
        Matrix dataPoints = LumpyData.lumpyRandomData(NUM_DATA_POINTS, NUM_DIMENSIONS);
        Matrix queries = LumpyData.lumpyRandomData(NUM_QUERIES, NUM_DIMENSIONS);

        CosineDistanceMeasure distanceMeasure = new CosineDistanceMeasure();
        BruteSearch bruteSearch = new BruteSearch(distanceMeasure);
        bruteSearch.addAll(dataPoints);

        Pair<List<List<WeightedThing<Vector>>>, Long> reference =
            getResultsAndRuntime(bruteSearch, queries);
        Pair<List<WeightedThing<Vector>>, Long> referenceSearchFirst =
            getResultsAndRuntimeSearchFirst(bruteSearch, queries);
        System.out.printf("BruteSearch: avg_time(1 query) %f[s]\n",
            averageSecondsPerQuery(reference.getSecond(), queries.numRows()));

        return Arrays.asList(
            new Object[] {new ProjectionSearch(distanceMeasure, 3, 10),
                dataPoints, queries, reference, referenceSearchFirst},
            new Object[] {new FastProjectionSearch(distanceMeasure, 3, 10),
                dataPoints, queries, reference, referenceSearchFirst},
            new Object[] {new ProjectionSearch(distanceMeasure, 5, 5),
                dataPoints, queries, reference, referenceSearchFirst},
            new Object[] {new FastProjectionSearch(distanceMeasure, 5, 5),
                dataPoints, queries, reference, referenceSearchFirst});
    }

    public SearchQualityTest(Searcher searcher, Matrix dataPoints, Matrix queries,
            Pair<List<List<WeightedThing<Vector>>>, Long> reference,
            Pair<List<WeightedThing<Vector>>, Long> referenceSearchFirst) {
        this.searcher = searcher;
        this.dataPoints = dataPoints;
        this.queries = queries;
        this.reference = reference;
        this.referenceSearchFirst = referenceSearchFirst;
    }

    /**
     * Checks that {@code searchFirst} returns the brute-force closest vector for every query, and
     * that it is faster on average than brute-force {@code searchFirst}.
     */
    @Test
    public void testOverlapAndRuntimeSearchFirst() {
        searcher.clear();
        searcher.addAll(dataPoints);
        Pair<List<WeightedThing<Vector>>, Long> results =
            getResultsAndRuntimeSearchFirst(searcher, queries);

        List<WeightedThing<Vector>> expected = referenceSearchFirst.getFirst();
        List<WeightedThing<Vector>> actual = results.getFirst();
        int numQueries = queries.numRows();
        int numFirstMatches = 0;
        for (int i = 0; i < numQueries; i++) {
            if (expected.get(i).getValue().equals(actual.get(i).getValue())) {
                numFirstMatches++;
            }
        }

        // Baseline is the brute-force searchFirst timing, so both sides ran the same kind of query.
        double bruteAvgTime = averageSecondsPerQuery(referenceSearchFirst.getSecond(), numQueries);
        double searcherAvgTime = averageSecondsPerQuery(results.getSecond(), numQueries);
        String searcherName = searcher.getClass().getName();
        System.out.printf("%s: first matches %d [%d]; avg_time(1 query) %f(s) [%f]\n",
            searcherName, numFirstMatches, numQueries, searcherAvgTime, bruteAvgTime);

        Assert.assertEquals("Closest vector returned doesn't match", numQueries, numFirstMatches);
        Assert.assertTrue("Searcher " + searcherName + " slower than brute",
            bruteAvgTime > searcherAvgTime);
    }

    /**
     * Checks that top-{@link #NUM_RESULTS} search returns the brute-force closest vector first for
     * every query, and that it is faster on average than brute force. Also reports how many of the
     * reference top-k vectors were found in the searcher's top-k.
     */
    @Test
    public void testOverlapAndRuntime() {
        searcher.clear();
        searcher.addAll(dataPoints);
        Pair<List<List<WeightedThing<Vector>>>, Long> results =
            getResultsAndRuntime(searcher, queries);

        Function<WeightedThing<Vector>, Vector> stripWeight = new StripWeight();
        int numQueries = queries.numRows();
        int numFirstMatches = 0;
        int numTotalMatches = 0;
        for (int i = 0; i < numQueries; i++) {
            List<WeightedThing<Vector>> expected = reference.getFirst().get(i);
            List<WeightedThing<Vector>> actual = results.getFirst().get(i);
            if (expected.get(0).getValue().equals(actual.get(0).getValue())) {
                numFirstMatches++;
            }
            // Count every (reference, result) pair of equal vectors; the lists hold NUM_RESULTS items.
            for (Vector expectedVector : Iterables.transform(expected, stripWeight)) {
                for (Vector actualVector : Iterables.transform(actual, stripWeight)) {
                    if (expectedVector.equals(actualVector)) {
                        numTotalMatches++;
                    }
                }
            }
        }

        double bruteAvgTime = averageSecondsPerQuery(reference.getSecond(), numQueries);
        double searcherAvgTime = averageSecondsPerQuery(results.getSecond(), numQueries);
        String searcherName = searcher.getClass().getName();
        System.out.printf(
            "%s: first matches %d [%d]; total matches %d [%d]; avg_time(1 query) %f(s) [%f]\n",
            searcherName, numFirstMatches, numQueries, numTotalMatches, numQueries * NUM_RESULTS,
            searcherAvgTime, bruteAvgTime);

        Assert.assertEquals("Closest vector returned doesn't match", numQueries, numFirstMatches);
        Assert.assertTrue("Searcher " + searcherName + " slower than brute",
            bruteAvgTime > searcherAvgTime);
    }

    /**
     * Runs a top-{@link #NUM_RESULTS} search for all queries and times it.
     *
     * @param searcher searcher to query
     * @param queries  query vectors
     * @return the per-query result lists, and the elapsed wall time of the search in milliseconds
     */
    public static Pair<List<List<WeightedThing<Vector>>>, Long> getResultsAndRuntime(
            Searcher searcher, Iterable<? extends Vector> queries) {
        // Read the clock before and after the search so the elapsed time actually covers it.
        long startNanos = System.nanoTime();
        List<List<WeightedThing<Vector>>> results = searcher.search(queries, NUM_RESULTS);
        long elapsedMillis = (System.nanoTime() - startNanos) / NANOS_PER_MILLI;
        return new Pair<List<List<WeightedThing<Vector>>>, Long>(results, Long.valueOf(elapsedMillis));
    }

    /**
     * Runs {@code searchFirst} (allowing a result equal to the query) for all queries and times it.
     *
     * @param searcher searcher to query
     * @param queries  query vectors
     * @return the closest result per query, and the elapsed wall time of the search in milliseconds
     */
    public static Pair<List<WeightedThing<Vector>>, Long> getResultsAndRuntimeSearchFirst(
            Searcher searcher, Iterable<? extends Vector> queries) {
        long startNanos = System.nanoTime();
        List<WeightedThing<Vector>> results = searcher.searchFirst(queries, false);
        long elapsedMillis = (System.nanoTime() - startNanos) / NANOS_PER_MILLI;
        return new Pair<List<WeightedThing<Vector>>, Long>(results, Long.valueOf(elapsedMillis));
    }

    /** Converts a total run time in milliseconds to the average time per query in seconds. */
    private static double averageSecondsPerQuery(long totalMillis, int numQueries) {
        return totalMillis / 1000.0 / numQueries;
    }
}
