package org.apache.mahout.math.neighborhood;

import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.Iterators;
import com.google.common.collect.Lists;
import com.google.common.collect.Multiset;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.apache.lucene.util.PriorityQueue;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.hadoop.similarity.cooccurrence.measures.VectorSimilarityMeasure;
import org.apache.mahout.math.random.RandomProjector;
import org.apache.mahout.math.random.WeightedThing;
import org.apache.mahout.math.stats.OnlineSummarizer;

/* loaded from: input_file:org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearch.class */
/**
 * Approximate nearest-neighbor searcher based on locality sensitive hashing.
 *
 * <p>Every added vector is stored as a {@link HashedVector}, i.e. together with a 64-bit hash
 * computed from a random projection matrix. At query time the Hamming distance between the
 * query's hash and each stored hash is used as a cheap filter that decides which stored vectors
 * get a full {@link DistanceMeasure} evaluation.
 *
 * <p>The projection matrix is created lazily on the first {@link #add(Vector)} call, using the
 * size of that first vector. Searching or removing before anything was added is treated as
 * operating on an empty index.
 *
 * <p>This class keeps mutable per-search state ({@code distanceEvaluations}) and performs no
 * synchronization of its own.
 */
public class LocalitySensitiveHashSearch extends UpdatableSearcher {
    /** Number of hash bits; also the first argument given to the random projection generator. */
    private static final int BITS = 64;
    /** Mask handed to {@link HashedVector}; all 64 bits set. */
    private static final long BIT_MASK = -1L;
    /** Upper bound up to which the Hamming-distance cutoff may be raised again during a search. */
    private static final int MAX_HASH_LIMIT = 32;
    /** A per-Hamming-distance summary must hold more samples than this before it is consulted. */
    private static final int MIN_DISTRIBUTION_COUNT = 10;
    /**
     * Index value handed to {@link HashedVector}.
     * NOTE(review): presumably HashedVector's "invalid index" marker - confirm against HashedVector.
     */
    private static final int NO_INDEX = -1;

    private final Multiset<HashedVector> trainingVectors;
    private Matrix projection;
    private int searchSize;
    private double hashLimitStrategy;
    private int distanceEvaluations;
    private boolean initialized;

    /**
     * Creates an empty searcher.
     *
     * @param distanceMeasure measure used for the full distance computations
     * @param searchSize      capacity of the candidate queue kept during a search
     */
    public LocalitySensitiveHashSearch(DistanceMeasure distanceMeasure, int searchSize) {
        super(distanceMeasure);
        this.trainingVectors = HashMultiset.create();
        this.hashLimitStrategy = 0.9d;
        this.distanceEvaluations = 0;
        this.initialized = false;
        this.searchSize = searchSize;
        this.projection = null;
    }

    /**
     * Creates the projection matrix on first use; later calls are no-ops, so the dimension of
     * the first added vector fixes the projection.
     *
     * @param numDimensions size of the vectors that will be hashed
     */
    private void initialize(int numDimensions) {
        if (this.initialized) {
            return;
        }
        this.initialized = true;
        this.projection = RandomProjector.generateBasisNormal(BITS, numDimensions);
    }

    /**
     * Scans all stored vectors and collects candidate neighbors of {@code query}.
     *
     * <p>The returned queue holds {@link HashedVector} instances (typed as {@link Vector});
     * callers strip the hash with {@link #removeHash(WeightedThing)}. As a side effect,
     * {@code distanceEvaluations} is reset and then counts the full distance computations done.
     *
     * @param query the vector to search for
     * @return queue of candidates; empty if nothing was ever added
     */
    private PriorityQueue<WeightedThing<Vector>> searchInternal(Vector query) {
        PriorityQueue<WeightedThing<Vector>> top = Searcher.getCandidateQueue(getSearchSize());
        this.distanceEvaluations = 0;
        if (!this.initialized) {
            // No vector was ever added, so there is no projection to hash the query with.
            return top;
        }

        long queryHash = HashedVector.computeHash64(query, this.projection);

        // One running summary of observed distances per possible Hamming distance (0..BITS).
        OnlineSummarizer[] distribution = new OnlineSummarizer[BITS + 1];
        for (int k = 0; k <= BITS; k++) {
            distribution[k] = new OnlineSummarizer();
        }

        // Number of accepted candidates seen per Hamming distance.
        int[] hashCounts = new int[BITS + 1];

        // Only vectors whose Hamming distance is <= hashLimit get a full distance evaluation.
        int hashLimit = BITS;
        // Running count of accepted candidates, adjusted whenever hashLimit moves.
        int limitCount = 0;
        // Candidates must beat this distance; tightened once the queue holds searchSize entries.
        double distanceLimit = Double.POSITIVE_INFINITY;

        for (HashedVector candidate : this.trainingVectors) {
            int bitDot = candidate.hammingDistance(queryHash);
            if (bitDot > hashLimit) {
                continue;
            }
            this.distanceEvaluations++;
            double distance = this.distanceMeasure.distance(query, candidate);
            distribution[bitDot].add(distance);
            if (distance >= distanceLimit) {
                continue;
            }

            top.insertWithOverflow(new WeightedThing<Vector>(candidate, distance));
            if (top.size() == this.searchSize) {
                distanceLimit = top.top().getWeight();
            }

            hashCounts[bitDot]++;
            limitCount++;

            // Lower the cutoff while more than searchSize candidates remain without the
            // bucket just below it.
            while (hashLimit > 0 && limitCount - hashCounts[hashLimit - 1] > this.searchSize) {
                hashLimit--;
                limitCount -= hashCounts[hashLimit];
            }

            // A negative strategy disables raising the cutoff again.
            if (this.hashLimitStrategy >= 0.0d) {
                // Raise the cutoff while the bucket at the cutoff has enough samples and a
                // blend of its quartile 0 and quartile 1 distances is below the distance limit.
                while (hashLimit < MAX_HASH_LIMIT
                        && distribution[hashLimit].getCount() > MIN_DISTRIBUTION_COUNT
                        && ((1.0d - this.hashLimitStrategy) * distribution[hashLimit].getQuartile(0))
                                + (this.hashLimitStrategy * distribution[hashLimit].getQuartile(1))
                                < distanceLimit) {
                    limitCount += hashCounts[hashLimit];
                    hashLimit++;
                }
            }
        }
        return top;
    }

    /**
     * Returns up to {@code limit} candidates for {@code query}, with the hash wrapper removed.
     *
     * <p>Candidates are popped from the queue and the resulting list is reversed before being
     * truncated. NOTE(review): this presumably yields ascending distance order; that depends on
     * the ordering of {@code Searcher.getCandidateQueue}, which is not visible here.
     *
     * @param query the vector to search for
     * @param limit maximum number of results to return
     * @return at most {@code limit} results; empty if the index is empty
     */
    @Override
    public List<WeightedThing<Vector>> search(Vector query, int limit) {
        PriorityQueue<WeightedThing<Vector>> candidates = searchInternal(query);
        List<WeightedThing<Vector>> results = Lists.newArrayListWithExpectedSize(candidates.size());
        while (candidates.size() != 0) {
            results.add(removeHash(candidates.pop()));
        }
        Collections.reverse(results);
        if (limit < results.size()) {
            results = results.subList(0, limit);
        }
        return results;
    }

    /**
     * Returns a single result for {@code query}.
     *
     * <p>The candidate queue is trimmed to its last two entries. Of those, the one popped last
     * is preferred, unless {@code differentThanQuery} is set and it equals the query, in which
     * case the other one is returned.
     *
     * @param query              the vector to search for
     * @param differentThanQuery if true, avoid returning a vector equal to {@code query} when
     *                           a second candidate is available
     * @return the selected result, or {@code null} if there are no candidates at all
     */
    @Override
    public WeightedThing<Vector> searchFirst(Vector query, boolean differentThanQuery) {
        PriorityQueue<WeightedThing<Vector>> candidates = searchInternal(query);
        if (candidates.size() == 0) {
            // Empty index: popping would hand null to removeHash and fail with an NPE.
            return null;
        }
        while (candidates.size() > 2) {
            candidates.pop();
        }
        if (candidates.size() < 2) {
            return removeHash(candidates.pop());
        }
        WeightedThing<Vector> secondBest = candidates.pop();
        WeightedThing<Vector> best = candidates.pop();
        if (differentThanQuery && best.getValue().equals(query)) {
            best = secondBest;
        }
        return removeHash(best);
    }

    /**
     * Unwraps a search result whose value is a {@link HashedVector}.
     *
     * @param weightedThing result whose value must be a {@link HashedVector}
     * @return a new result with the same weight and the underlying vector as value
     */
    protected static WeightedThing<Vector> removeHash(WeightedThing<Vector> weightedThing) {
        return new WeightedThing<>(((HashedVector) weightedThing.getValue()).getVector(), weightedThing.getWeight());
    }

    /**
     * Hashes {@code vector} and stores it, creating the projection on the first call.
     *
     * @param vector the vector to add
     */
    @Override
    public void add(Vector vector) {
        initialize(vector.size());
        this.trainingVectors.add(new HashedVector(vector, this.projection, NO_INDEX, BIT_MASK));
    }

    /** Returns the number of stored vectors, counting duplicates. */
    @Override
    public int size() {
        return this.trainingVectors.size();
    }

    /** Returns the capacity used for the candidate queue during a search. */
    public int getSearchSize() {
        return this.searchSize;
    }

    /** Sets the capacity used for the candidate queue during a search. */
    public void setSearchSize(int searchSize) {
        this.searchSize = searchSize;
    }

    /**
     * Sets the weight used when deciding whether to raise the Hamming-distance cutoff during a
     * search; a negative value disables raising it.
     *
     * @param strategy blend weight between quartile 0 and quartile 1 of the observed distances
     */
    public void setRaiseHashLimitStrategy(double strategy) {
        this.hashLimitStrategy = strategy;
    }

    /**
     * Returns the number of full distance computations counted so far and resets the counter.
     *
     * @return the counter value before the reset
     */
    public int resetEvaluationCount() {
        int evaluations = this.distanceEvaluations;
        this.distanceEvaluations = 0;
        return evaluations;
    }

    /** Iterates over the stored vectors with the hash wrapper removed. */
    @Override
    public Iterator<Vector> iterator() {
        return Iterators.transform(this.trainingVectors.iterator(), new Function<HashedVector, Vector>() {
            @Override
            public Vector apply(HashedVector hashedVector) {
                Preconditions.checkNotNull(hashedVector);
                return hashedVector.getVector();
            }
        });
    }

    /**
     * Removes one occurrence of {@code vector} from the index.
     *
     * @param vector  the vector to remove
     * @param epsilon not used by this implementation
     * @return true if an entry was removed; false if none matched or nothing was ever added
     */
    @Override
    public boolean remove(Vector vector, double epsilon) {
        if (!this.initialized) {
            // Nothing was ever added, so there is no projection to hash with and nothing to remove.
            return false;
        }
        return this.trainingVectors.remove(new HashedVector(vector, this.projection, NO_INDEX, BIT_MASK));
    }

    /** Removes all stored vectors; an already created projection is kept. */
    @Override
    public void clear() {
        this.trainingVectors.clear();
    }
}
