package org.apache.mahout.math.neighborhood;

import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixSlice;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.random.MultiNormal;
import org.apache.mahout.math.random.WeightedThing;
import org.hamcrest.Matchers;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

/**
 * Sanity checks shared by the {@link UpdatableSearcher} implementations.
 *
 * <p>The {@link Parameterized} runner builds one test instance per row returned by
 * {@link #generateData()}: a {@link ProjectionSearch}, a {@link FastProjectionSearch} and a
 * {@link LocalitySensitiveHashSearch}. All of them use {@link EuclideanDistanceMeasure} and are
 * exercised against the same {@code NUM_DATA_POINTS x NUM_DIMENSIONS} random data matrix. The
 * searcher objects are created once in {@link #generateData()}, so every test starts with
 * {@code searcher.clear()} to avoid seeing state left behind by another test.
 */
@RunWith(Parameterized.class)
public class SearchSanityTest extends MahoutTestCase {
    private static final int NUM_DATA_POINTS = 8192;
    private static final int NUM_DIMENSIONS = 20;
    private static final int NUM_PROJECTIONS = 3;
    private static final int SEARCH_SIZE = 30;

    private final UpdatableSearcher searcher;
    private final Matrix dataPoints;

    /**
     * Builds a {@code numDataPoints x numDimensions} dense matrix whose rows are independent samples
     * drawn from {@code new MultiNormal(numDimensions)} (presumably zero-mean, unit-covariance).
     *
     * @param numDataPoints number of rows to generate
     * @param numDimensions number of columns, i.e. the dimensionality of each sample
     * @return the freshly filled matrix
     */
    protected static Matrix multiNormalRandomData(int numDataPoints, int numDimensions) {
        Matrix data = new DenseMatrix(numDataPoints, numDimensions);
        // Size the generator from the requested width (previously hard-coded to NUM_DIMENSIONS,
        // which broke any caller asking for a different number of columns).
        MultiNormal generator = new MultiNormal(numDimensions);
        for (MatrixSlice row : data) {
            // The slice's vector is a view of the row, so assign() fills the matrix in place.
            row.vector().assign(generator.sample());
        }
        return data;
    }

    /**
     * Supplies the {@code (searcher, data)} pairs this test is run with.
     *
     * <p>The test seed is fixed first so the data set is reproducible. Construction order is kept
     * stable (data, then each searcher), because constructing a searcher may itself draw random
     * numbers.
     */
    @Parameterized.Parameters
    public static List<Object[]> generateData() {
        RandomUtils.useTestSeed();
        Matrix data = multiNormalRandomData(NUM_DATA_POINTS, NUM_DIMENSIONS);
        return Arrays.asList(
            new Object[] {
                new ProjectionSearch(new EuclideanDistanceMeasure(), NUM_PROJECTIONS, SEARCH_SIZE), data},
            new Object[] {
                new FastProjectionSearch(new EuclideanDistanceMeasure(), NUM_PROJECTIONS, SEARCH_SIZE), data},
            new Object[] {
                new LocalitySensitiveHashSearch(new EuclideanDistanceMeasure(), SEARCH_SIZE), data});
    }

    /**
     * @param searcher the searcher implementation under test (shared across test methods)
     * @param dataPoints the data set to index
     */
    public SearchSanityTest(UpdatableSearcher searcher, Matrix dataPoints) {
        this.searcher = searcher;
        this.dataPoints = dataPoints;
    }

    /**
     * Adds the data in three batches (rows 0-299, rows 300-309, then the remainder). After each
     * batch it checks the size and that an indexed row is found at distance zero. Finally it checks
     * that the first 100 rows are found exactly, with a strictly more distant second hit.
     */
    @Test
    public void testExactMatch() {
        searcher.clear();
        Iterable<MatrixSlice> firstBatch = Iterables.limit(dataPoints, 300);
        List<MatrixSlice> queries = Lists.newArrayList(Iterables.limit(firstBatch, 100));

        searcher.addAllMatrixSlices(firstBatch);
        assertEquals(300, searcher.size());
        Vector row0 = Iterables.get(dataPoints, 0).vector();
        assertEquals(0.0, searcher.search(row0, 2).get(0).getValue().minus(row0).norm(1), 1.0e-8);

        searcher.addAllMatrixSlices(Iterables.limit(Iterables.skip(dataPoints, 300), 10));
        assertEquals(310, searcher.size());
        Vector row302 = Iterables.get(dataPoints, 302).vector();
        assertEquals(0.0, searcher.search(row302, 2).get(0).getValue().minus(row302).norm(1), 1.0e-8);

        searcher.addAllMatrixSlices(Iterables.skip(dataPoints, 310));
        assertEquals(dataPoints.numRows(), searcher.size());

        for (MatrixSlice query : queries) {
            List<WeightedThing<Vector>> hits = searcher.search(query.vector(), 2);
            assertEquals("Distance has to be about zero", 0.0, hits.get(0).getWeight(), 1.0e-6);
            assertEquals("Answer must be substantially the same as query", 0.0,
                hits.get(0).getValue().minus(query.vector()).norm(1), 1.0e-8);
            assertTrue("Wrong answer must have non-zero distance",
                hits.get(1).getWeight() > hits.get(0).getWeight());
        }
    }

    /**
     * Checks results for the first 100 rows against a small Gaussian noise vector.
     * The rows are indexed as weighted vectors.
     */
    @Test
    public void testNearMatch() {
        searcher.clear();
        List<MatrixSlice> queries = Lists.newArrayList(Iterables.limit(dataPoints, 100));
        searcher.addAllMatrixSlicesAsWeightedVectors(dataPoints);

        MultiNormal noise = new MultiNormal(0.01, new DenseVector(NUM_DIMENSIONS));
        for (MatrixSlice slice : queries) {
            Vector query = slice.vector();
            Vector epsilon = noise.sample();
            // NOTE(review): the search runs on the *unperturbed* query; epsilon is only added
            // afterwards. Assuming the top hit is the query row itself, the first assertion reduces
            // to "||epsilon|| is within 0.1 of zero" and the second holds trivially, so no genuinely
            // near (non-exact) lookup is exercised. Searching query.plus(epsilon) would be the
            // stronger test, but the approximate searchers (3 projections, search size 30) may then
            // miss some rows. Confirm tolerances/search size before changing it.
            List<WeightedThing<Vector>> hits = searcher.search(query, 2);
            Vector noisyQuery = query.plus(epsilon);
            assertEquals("Distance has to be small", epsilon.norm(2), hits.get(0).getWeight(), 0.1);
            assertEquals("Answer must be substantially the same as query", epsilon.norm(2),
                hits.get(0).getValue().minus(noisyQuery).norm(2), 0.1);
            assertTrue("Wrong answer must be further away",
                hits.get(1).getWeight() > hits.get(0).getWeight());
        }
    }

    /** Verifies that search results come back in non-decreasing order of distance. */
    @Test
    public void testOrdering() {
        searcher.clear();
        Matrix queries = multiNormalRandomData(100, NUM_DIMENSIONS);
        searcher.addAllMatrixSlices(dataPoints);
        for (MatrixSlice query : queries) {
            double previous = 0.0;
            for (WeightedThing<Vector> hit : searcher.search(query.vector(), 200)) {
                assertTrue("Scores must be monotonic increasing", hit.getWeight() >= previous);
                previous = hit.getWeight();
            }
        }
    }

    /**
     * Removes the first two vectors the searcher iterates over.
     *
     * <p>After each removal it checks that the size shrinks and the vector is no longer found at
     * distance zero. After the first removal, the former second-nearest neighbour becomes the
     * nearest. At the end, iteration must no longer yield either removed vector.
     */
    @Test
    public void testRemoval() {
        searcher.clear();
        searcher.addAllMatrixSlices(dataPoints);
        // The field is statically an UpdatableSearcher, so removal is always expected to work.
        // The former "not updatable -> expect UnsupportedOperationException" branch was unreachable.
        List<Vector> toRemove = Lists.newArrayList(Iterables.limit(searcher, 2));
        int initialSize = searcher.size();

        List<WeightedThing<Vector>> before = searcher.search(toRemove.get(0), 2);
        searcher.remove(toRemove.get(0), 1.0e-7);
        assertEquals(initialSize - 1, searcher.size());
        List<WeightedThing<Vector>> after = searcher.search(toRemove.get(0), 1);
        assertTrue("Vector should be gone", after.get(0).getWeight() > 0.0);
        assertEquals("Previous second neighbor should be first", 0.0,
            after.get(0).getValue().minus(before.get(1).getValue()).norm(1), 1.0e-8);

        searcher.remove(toRemove.get(1), 1.0e-7);
        assertEquals(initialSize - 2, searcher.size());
        assertTrue("Vector should be gone",
            searcher.search(toRemove.get(1), 1).get(0).getWeight() > 0.0);

        for (Vector remaining : searcher) {
            assertTrue(toRemove.get(0).minus(remaining).norm(1) > 1.0e-6);
            assertTrue(toRemove.get(1).minus(remaining).norm(1) > 1.0e-6);
        }
    }

    /**
     * Checks, for every row, that {@code searchFirst} agrees with the top two entries of
     * {@code search}. With {@code false} the row should match itself at distance exactly zero;
     * with {@code true} it should match the second entry.
     */
    @Test
    public void testSearchFirst() {
        searcher.clear();
        searcher.addAll(dataPoints);
        for (MatrixSlice query : dataPoints) {
            WeightedThing<Vector> first = searcher.searchFirst(query, false);
            WeightedThing<Vector> firstOther = searcher.searchFirst(query, true);
            List<WeightedThing<Vector>> hits = searcher.search(query, 2);
            assertEquals("First isn't self", 0.0, first.getWeight(), 0.0);
            assertEquals("First isn't self", query, first.getValue());
            assertEquals("First doesn't match", first, hits.get(0));
            assertEquals("Second doesn't match", firstOther, hits.get(1));
        }
    }

    /** Verifies that {@code search(query, 2)} never returns more than two results. */
    @Test
    public void testSearchLimiting() {
        searcher.clear();
        searcher.addAll(dataPoints);
        for (MatrixSlice query : dataPoints) {
            assertThat("Search limit isn't respected", searcher.search(query, 2).size(),
                Matchers.is(Matchers.lessThanOrEqualTo(2)));
        }
    }

    /**
     * Interleaves adds and removes. Rows are added one at a time; each even-indexed row is
     * looked up (via both {@code search} and {@code searchFirst}) and removed immediately after
     * insertion, so only odd-indexed rows remain at the end.
     */
    @Test
    public void testRemove() {
        searcher.clear();
        for (int i = 0; i < dataPoints.numRows(); i++) {
            Vector row = dataPoints.viewRow(i);
            searcher.add(row);
            if (i % 2 == 0) {
                assertTrue("Failed to find self [search]",
                    searcher.search(row, 1).get(0).getWeight() < 1.0e-6);
                assertTrue("Failed to find self [searchFirst]",
                    searcher.searchFirst(row, false).getWeight() < 1.0e-6);
                assertTrue("Failed to remove self", searcher.remove(row, 1.0e-6));
            }
        }
    }
}
