/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.math.neighborhood;

import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.TimeUnit;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.distance.CosineDistanceMeasure;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.neighborhood.BruteSearch;
import org.apache.mahout.math.neighborhood.FastProjectionSearch;
import org.apache.mahout.math.neighborhood.LumpyData;
import org.apache.mahout.math.neighborhood.ProjectionSearch;
import org.apache.mahout.math.neighborhood.Searcher;
import org.apache.mahout.math.random.WeightedThing;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

@RunWith(value=Parameterized.class)
public class SearchQualityTest {
    private static final int NUM_DATA_POINTS = 16384;
    private static final int NUM_QUERIES = 1024;
    private static final int NUM_DIMENSIONS = 40;
    private static final int NUM_RESULTS = 2;
    private final Searcher searcher;
    private final Matrix dataPoints;
    private final Matrix queries;
    private Pair<List<List<WeightedThing<Vector>>>, Long> reference;
    private Pair<List<WeightedThing<Vector>>, Long> referenceSearchFirst;

    @Parameterized.Parameters
    public static List<Object[]> generateData() {
        RandomUtils.useTestSeed();
        Matrix dataPoints = LumpyData.lumpyRandomData(16384, 40);
        Matrix queries = LumpyData.lumpyRandomData(1024, 40);
        CosineDistanceMeasure distanceMeasure = new CosineDistanceMeasure();
        BruteSearch bruteSearcher = new BruteSearch((DistanceMeasure)distanceMeasure);
        bruteSearcher.addAll((Iterable)dataPoints);
        Pair<List<List<WeightedThing<Vector>>>, Long> reference = SearchQualityTest.getResultsAndRuntime((Searcher)bruteSearcher, (Iterable<? extends Vector>)queries);
        Pair<List<WeightedThing<Vector>>, Long> referenceSearchFirst = SearchQualityTest.getResultsAndRuntimeSearchFirst((Searcher)bruteSearcher, (Iterable<? extends Vector>)queries);
        double bruteSearchAvgTime = (double)((Long)reference.getSecond()).longValue() / ((double)queries.numRows() * 1.0);
        System.out.printf("BruteSearch: avg_time(1 query) %f[s]\n", bruteSearchAvgTime);
        return Arrays.asList({new ProjectionSearch((DistanceMeasure)distanceMeasure, 3, 10), dataPoints, queries, reference, referenceSearchFirst}, {new FastProjectionSearch((DistanceMeasure)distanceMeasure, 3, 10), dataPoints, queries, reference, referenceSearchFirst}, {new ProjectionSearch((DistanceMeasure)distanceMeasure, 5, 5), dataPoints, queries, reference, referenceSearchFirst}, {new FastProjectionSearch((DistanceMeasure)distanceMeasure, 5, 5), dataPoints, queries, reference, referenceSearchFirst});
    }

    public SearchQualityTest(Searcher searcher, Matrix dataPoints, Matrix queries, Pair<List<List<WeightedThing<Vector>>>, Long> reference, Pair<List<WeightedThing<Vector>>, Long> referenceSearchFirst) {
        this.searcher = searcher;
        this.dataPoints = dataPoints;
        this.queries = queries;
        this.reference = reference;
        this.referenceSearchFirst = referenceSearchFirst;
    }

    @Test
    public void testOverlapAndRuntimeSearchFirst() {
        this.searcher.clear();
        this.searcher.addAll((Iterable)this.dataPoints);
        Pair<List<WeightedThing<Vector>>, Long> results = SearchQualityTest.getResultsAndRuntimeSearchFirst(this.searcher, (Iterable<? extends Vector>)this.queries);
        int numFirstMatches = 0;
        for (int i = 0; i < this.queries.numRows(); ++i) {
            WeightedThing referenceVector = (WeightedThing)((List)this.referenceSearchFirst.getFirst()).get(i);
            WeightedThing resultVector = (WeightedThing)((List)results.getFirst()).get(i);
            if (!((Vector)referenceVector.getValue()).equals(resultVector.getValue())) continue;
            ++numFirstMatches;
        }
        double bruteSearchAvgTime = (double)((Long)this.reference.getSecond()).longValue() / ((double)this.queries.numRows() * 1.0);
        double searcherAvgTime = (double)((Long)results.getSecond()).longValue() / ((double)this.queries.numRows() * 1.0);
        System.out.printf("%s: first matches %d [%d]; avg_time(1 query) %f(s) [%f]\n", this.searcher.getClass().getName(), numFirstMatches, this.queries.numRows(), searcherAvgTime, bruteSearchAvgTime);
        Assert.assertEquals((String)"Closest vector returned doesn't match", (long)this.queries.numRows(), (long)numFirstMatches);
        Assert.assertTrue((String)("Searcher " + this.searcher.getClass().getName() + " slower than brute"), (bruteSearchAvgTime > searcherAvgTime ? 1 : 0) != 0);
    }

    @Test
    public void testOverlapAndRuntime() {
        this.searcher.clear();
        this.searcher.addAll((Iterable)this.dataPoints);
        Pair<List<List<WeightedThing<Vector>>>, Long> results = SearchQualityTest.getResultsAndRuntime(this.searcher, (Iterable<? extends Vector>)this.queries);
        int numFirstMatches = 0;
        int numMatches = 0;
        StripWeight stripWeight = new StripWeight();
        for (int i = 0; i < this.queries.numRows(); ++i) {
            List referenceVectors = (List)((List)this.reference.getFirst()).get(i);
            List resultVectors = (List)((List)results.getFirst()).get(i);
            if (((Vector)((WeightedThing)referenceVectors.get(0)).getValue()).equals(((WeightedThing)resultVectors.get(0)).getValue())) {
                ++numFirstMatches;
            }
            for (Vector v : Iterables.transform((Iterable)referenceVectors, (Function)stripWeight)) {
                for (Vector w : Iterables.transform((Iterable)resultVectors, (Function)stripWeight)) {
                    if (!v.equals(w)) continue;
                    ++numMatches;
                }
            }
        }
        double bruteSearchAvgTime = (double)((Long)this.reference.getSecond()).longValue() / ((double)this.queries.numRows() * 1.0);
        double searcherAvgTime = (double)((Long)results.getSecond()).longValue() / ((double)this.queries.numRows() * 1.0);
        System.out.printf("%s: first matches %d [%d]; total matches %d [%d]; avg_time(1 query) %f(s) [%f]\n", this.searcher.getClass().getName(), numFirstMatches, this.queries.numRows(), numMatches, this.queries.numRows() * 2, searcherAvgTime, bruteSearchAvgTime);
        Assert.assertEquals((String)"Closest vector returned doesn't match", (long)this.queries.numRows(), (long)numFirstMatches);
        Assert.assertTrue((String)("Searcher " + this.searcher.getClass().getName() + " slower than brute"), (bruteSearchAvgTime > searcherAvgTime ? 1 : 0) != 0);
    }

    public static Pair<List<List<WeightedThing<Vector>>>, Long> getResultsAndRuntime(Searcher searcher, Iterable<? extends Vector> queries) {
        long start = System.currentTimeMillis();
        List results = searcher.search(queries, 2);
        long end = System.currentTimeMillis();
        return new Pair((Object)results, (Object)(end - start));
    }

    public static Pair<List<WeightedThing<Vector>>, Long> getResultsAndRuntimeSearchFirst(Searcher searcher, Iterable<? extends Vector> queries) {
        long start = System.currentTimeMillis();
        List results = searcher.searchFirst(queries, false);
        long end = System.currentTimeMillis();
        return new Pair((Object)results, (Object)(end - start));
    }

    static class StripWeight
    implements Function<WeightedThing<Vector>, Vector> {
        StripWeight() {
        }

        public Vector apply(WeightedThing<Vector> input) {
            Preconditions.checkArgument((input != null ? 1 : 0) != 0, (Object)"input is null");
            return (Vector)input.getValue();
        }
    }
}

