/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.math.neighborhood;

import java.util.BitSet;
import java.util.List;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.WeightedVector;
import org.apache.mahout.math.function.DoubleFunction;
import org.apache.mahout.math.neighborhood.BruteSearch;
import org.apache.mahout.math.neighborhood.HashedVector;
import org.apache.mahout.math.neighborhood.LocalitySensitiveHashSearch;
import org.apache.mahout.math.random.Normal;
import org.apache.mahout.math.random.WeightedThing;
import org.apache.mahout.math.stats.OnlineSummarizer;
import org.junit.Assert;
import org.junit.Test;

public class LocalitySensitiveHashSearchTest {
    /** Number of rows in the synthetic data set indexed by both searchers. */
    private static final int ROWS = 100000;
    /** Dimensionality of every synthetic data vector. */
    private static final int COLUMNS = 10;
    /** Number of data rows, taken from the start of the matrix, used as queries per strategy. */
    private static final int QUERY_COUNT = 100;
    /** Result limit requested from the LSH searcher for each query. */
    private static final int LSH_RESULT_LIMIT = 150;
    /** Result limit requested from the brute-force reference searcher for each query. */
    private static final int EXACT_RESULT_LIMIT = 100;

    /**
     * Compares {@link LocalitySensitiveHashSearch} against {@link BruteSearch} on standard-normal
     * data across a sweep of raise-hash-limit strategies (-0.1 to 1.0 in steps of 0.1). For each
     * strategy it prints a CSV line of speedup and quartiles of the result overlap, and asserts
     * minimum quality/speed trade-offs.
     */
    @Test
    public void testNormal() {
        Normal gen = new Normal();
        Matrix testData = new DenseMatrix(ROWS, COLUMNS);
        testData.assign(gen);

        DistanceMeasure distance = new EuclideanDistanceMeasure();
        BruteSearch ref = new BruteSearch(distance);
        ref.addAllMatrixSlicesAsWeightedVectors(testData);

        LocalitySensitiveHashSearch cut = new LocalitySensitiveHashSearch(distance, 10);
        cut.addAllMatrixSlicesAsWeightedVectors(testData);
        cut.setSearchSize(200);
        cut.resetEvaluationCount();

        // Baseline cost: a brute-force pass compares every query against every row.
        double bruteForceEvaluations = (double) QUERY_COUNT * ROWS;

        System.out.printf("speedup,q1,q2,q3\n");
        for (int i = 0; i < 12; ++i) {
            double strategy = ((double) i - 1.0) / 10.0;
            cut.setRaiseHashLimitStrategy(strategy);
            OnlineSummarizer overlap = evaluateStrategy(testData, ref, cut);
            // Presumably the number of distance evaluations made by `cut` since the last reset.
            int evals = cut.resetEvaluationCount();
            double speedup = bruteForceEvaluations / evals;
            System.out.printf("%.1f,%.2f,%.2f,%.2f\n", speedup,
                overlap.getQuartile(1), overlap.getQuartile(2), overlap.getQuartile(3));

            // NOTE(review): the summarized values are overlap COUNTS in [0, EXACT_RESULT_LIMIT],
            // yet these thresholds look like fractions. As written they are nearly vacuous
            // (the median only needs to reach about 1). Confirm intent before tightening them.
            Assert.assertTrue(overlap.getQuartile(2) > 0.45);
            Assert.assertTrue(speedup > 4.0 || overlap.getQuartile(2) > 0.9);
            Assert.assertTrue(speedup > 15.0 || overlap.getQuartile(2) > 0.8);
        }
    }

    /**
     * Runs the first {@link #QUERY_COUNT} rows of {@code testData} as queries against both
     * searchers. For each query it records how many of the reference searcher's results also
     * appear in the LSH searcher's results.
     *
     * @param testData matrix whose leading rows are used as queries
     * @param ref exact brute-force searcher providing the reference neighbor sets
     * @param cut LSH searcher under test
     * @return a summarizer holding one overlap count per query
     */
    private static OnlineSummarizer evaluateStrategy(Matrix testData, BruteSearch ref, LocalitySensitiveHashSearch cut) {
        OnlineSummarizer overlap = new OnlineSummarizer();
        for (int i = 0; i < QUERY_COUNT; ++i) {
            Vector q = testData.viewRow(i);
            BitSet approximate = resultIndices(cut.search(q, LSH_RESULT_LIMIT));
            BitSet exact = resultIndices(ref.search(q, EXACT_RESULT_LIMIT));
            approximate.and(exact);
            overlap.add((double) approximate.cardinality());
        }
        return overlap;
    }

    /**
     * Collects the row indices of a list of search hits into a bit set.
     *
     * @param hits search results. Each value must be a {@link WeightedVector}, which holds here
     *     because both searchers were loaded via {@code addAllMatrixSlicesAsWeightedVectors}.
     * @return a bit set with the row index of every hit set
     */
    private static BitSet resultIndices(List<WeightedThing<Vector>> hits) {
        BitSet indices = new BitSet();
        for (WeightedThing<Vector> hit : hits) {
            indices.set(((WeightedVector) hit.getValue()).getIndex());
        }
        return indices;
    }

    /**
     * Exploratory test that prints how the Hamming distance between 64-bit projection hashes
     * relates to the cosine similarity of the underlying vectors. For each Hamming distance it
     * prints up to 199 "bits, cosine" samples, then a histogram of how often each distance
     * occurred. It makes no assertions.
     */
    @Test
    public void testDotCorrelation() {
        Normal gen = new Normal();
        Matrix projection = new DenseMatrix(64, 10);
        projection.assign(gen);
        Vector query = new DenseVector(10);
        query.assign(gen);
        long qhash = HashedVector.computeHash64(query, projection);
        // Loop-invariant: the query never changes inside the sampling loop.
        double queryLengthSquared = query.getLengthSquared();

        // Index = number of differing hash bits. A 64-bit hash gives 0..64, hence 65 slots.
        int[] count = new int[65];
        Vector v = new DenseVector(10);
        for (int i = 0; i < 500000; ++i) {
            v.assign(gen);
            long hash = HashedVector.computeHash64(v, projection);
            int bitDot = Long.bitCount(qhash ^ hash);
            count[bitDot]++;
            // Cap the per-distance sample output so the log stays readable.
            if (count[bitDot] < 200) {
                System.out.printf("%d, %.3f\n", bitDot,
                    v.dot(query) / Math.sqrt(v.getLengthSquared() * queryLengthSquared));
            }
        }
        for (int i = 0; i < count.length; ++i) {
            System.out.printf("%d, %d\n", i, count[i]);
        }
    }
}

