package org.apache.mahout.classifier.sgd;

import com.google.common.collect.Ordering;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.Vector;

/* loaded from: input_file:org/apache/mahout/classifier/sgd/ModelDissector.class */
public class ModelDissector {
    private final Map<String, Vector> weightMap = new HashMap();

    /* loaded from: input_file:org/apache/mahout/classifier/sgd/ModelDissector$Category.class */
    /**
     * Pairs a category index with the weight observed for that category.
     *
     * <p>The natural ordering ranks instances by the absolute value of their weight, smallest
     * first; when two magnitudes are equal, the instance with the larger index sorts first.
     */
    private static final class Category implements Comparable<Category> {
        private final int index;
        private final double weight;

        /**
         * @param i the category index
         * @param d the (signed) weight associated with that category
         */
        private Category(int i, double d) {
            this.index = i;
            this.weight = d;
        }

        /**
         * Orders by {@code |weight|} ascending, breaking ties by index descending.
         *
         * @param category the instance to compare against
         * @return a negative, zero or positive value per the {@link Comparable} contract
         */
        @Override
        public int compareTo(Category category) {
            int byMagnitude = Double.compare(Math.abs(this.weight), Math.abs(category.weight));
            if (byMagnitude != 0) {
                return byMagnitude;
            }
            // Equal magnitudes: the larger index is deliberately ordered first.
            if (category.index < this.index) {
                return -1;
            }
            return category.index > this.index ? 1 : 0;
        }

        /**
         * Two categories are equal when both index and weight match. The weight is compared
         * with {@link Double#compare(double, double)} rather than {@code ==}, so that a NaN
         * weight equals itself (keeping equals reflexive) and {@code 0.0} / {@code -0.0} are
         * told apart.
         * NOTE(review): assumes RandomUtils.hashDouble hashes the IEEE bit pattern, in which
         * case this keeps equals consistent with hashCode - confirm.
         */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof Category)) {
                return false;
            }
            Category other = (Category) obj;
            return this.index == other.index && Double.compare(this.weight, other.weight) == 0;
        }

        /** Combines the hash of the weight with the index. */
        @Override
        public int hashCode() {
            return RandomUtils.hashDouble(this.weight) ^ this.index;
        }
    }

    /* loaded from: input_file:org/apache/mahout/classifier/sgd/ModelDissector$Weight.class */
    /**
     * Summarizes the effect of one feature: its name plus the categories on which it has
     * the largest absolute weight, ordered from largest to smallest magnitude.
     *
     * <p>The natural ordering compares the absolute value of the top weight and then the
     * feature name.
     */
    public static class Weight implements Comparable<Weight> {
        private final String feature;
        private final double value;
        private final int maxIndex;
        private final List<Category> categories;

        /**
         * Builds a summary that keeps the three largest-magnitude categories.
         *
         * @param str    the feature name
         * @param vector the per-category weights for this feature
         */
        public Weight(String str, Vector vector) {
            this(str, vector, 3);
        }

        /**
         * Builds a summary that keeps at most {@code i} largest-magnitude categories.
         *
         * @param str    the feature name
         * @param vector the per-category weights for this feature
         * @param i      the maximum number of categories to retain
         * @throws IllegalArgumentException  if {@code i + 1} is less than 1 (rejected by the
         *                                   {@link PriorityQueue} constructor)
         * @throws IndexOutOfBoundsException if no category is retained, e.g. {@code i == 0}
         *                                   or {@code vector.all()} yields no elements
         */
        public Weight(String str, Vector vector, int i) {
            this.feature = str;
            // Min-heap on |weight|: evicting the head whenever the heap outgrows i leaves
            // the i largest magnitudes, while each Category keeps the original sign.
            PriorityQueue<Category> largest =
                new PriorityQueue<Category>(i + 1, Ordering.<Category>natural());
            for (Vector.Element element : vector.all()) {
                largest.add(new Category(element.index(), element.get()));
                while (largest.size() > i) {
                    largest.poll();
                }
            }
            this.categories = new ArrayList<Category>(largest);
            // Heap iteration order is unspecified, so sort explicitly, biggest magnitude first.
            Collections.sort(this.categories, Ordering.<Category>natural().reverse());
            Category top = this.categories.get(0);
            this.value = top.weight;
            this.maxIndex = top.index;
        }

        /**
         * Orders by the absolute value of the top weight, then by feature name.
         *
         * @param weight the instance to compare against
         * @return a negative, zero or positive value per the {@link Comparable} contract
         */
        @Override
        public int compareTo(Weight weight) {
            int byMagnitude = Double.compare(Math.abs(this.value), Math.abs(weight.value));
            return byMagnitude == 0 ? this.feature.compareTo(weight.feature) : byMagnitude;
        }

        /**
         * Equal when feature, top weight, top index and the retained category list all
         * match. The top weight is compared with {@link Double#compare(double, double)}
         * rather than {@code ==} so that {@code 0.0} and {@code -0.0} are not reported equal.
         * NOTE(review): assumes RandomUtils.hashDouble hashes the IEEE bit pattern, in which
         * case {@code ==} could pair equal objects with different hash codes - confirm.
         */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof Weight)) {
                return false;
            }
            Weight other = (Weight) obj;
            return this.feature.equals(other.feature)
                && Double.compare(this.value, other.value) == 0
                && this.maxIndex == other.maxIndex
                && this.categories.equals(other.categories);
        }

        /** Combines the hashes of every field that participates in {@link #equals(Object)}. */
        @Override
        public int hashCode() {
            return ((this.feature.hashCode() ^ RandomUtils.hashDouble(this.value)) ^ this.maxIndex) ^ this.categories.hashCode();
        }

        /** @return the feature name */
        public String getFeature() {
            return this.feature;
        }

        /** @return the signed weight with the largest absolute value */
        public double getWeight() {
            return this.value;
        }

        /**
         * @param i rank of the retained category, 0 being the largest magnitude
         * @return the signed weight of the category at that rank
         * @throws IndexOutOfBoundsException if {@code i} is outside the retained categories
         */
        public double getWeight(int i) {
            return this.categories.get(i).weight;
        }

        /**
         * @param i rank of the retained category, 0 being the largest magnitude
         * @return the index of the category at that rank; the int index is widened to
         *         {@code double} by the declared return type
         * @throws IndexOutOfBoundsException if {@code i} is outside the retained categories
         */
        public double getCategory(int i) {
            return this.categories.get(i).index;
        }

        /** @return the index of the category carrying the largest-magnitude weight */
        public int getMaxImpact() {
            return this.maxIndex;
        }
    }

    /**
     * Probes a classifier to record the score contributed by each traced feature.
     *
     * <p>The vector is first assigned 0. Then, for every entry whose key has no recorded
     * weight yet, the listed positions are set to 1, the result of
     * {@code classifyNoLink} on the vector is stored under that key, and the same positions
     * are set back to 0. Keys that already have a recorded weight are skipped.
     *
     * @param vector                   scratch feature vector; its contents are overwritten
     * @param map                      feature name to the vector positions that feature touches
     * @param abstractVectorClassifier the model being probed
     */
    public void update(Vector vector, Map<String, Set<Integer>> map, AbstractVectorClassifier abstractVectorClassifier) {
        vector.assign(0.0d);
        for (Map.Entry<String, Set<Integer>> traced : map.entrySet()) {
            String feature = traced.getKey();
            Set<Integer> locations = traced.getValue();
            if (this.weightMap.containsKey(feature)) {
                // Already probed on an earlier call; keep the first recorded weights.
                continue;
            }
            // Switch on only this feature's positions.
            for (Integer location : locations) {
                vector.set(location.intValue(), 1.0d);
            }
            this.weightMap.put(feature, abstractVectorClassifier.classifyNoLink(vector));
            // Switch them off again before the next feature is probed.
            for (Integer location : locations) {
                vector.set(location.intValue(), 0.0d);
            }
        }
    }

    /**
     * Returns the recorded features with the largest-magnitude weights.
     *
     * @param i the maximum number of features to return; values of 0 or less yield an
     *          empty list
     * @return a new list of at most {@code i} weights, sorted from largest to smallest
     *         according to {@link Weight#compareTo(Weight)}
     */
    public List<Weight> summary(int i) {
        if (i <= 0) {
            // Nothing can be retained. For a negative limit the eviction loop below could
            // never finish either: poll() on an empty queue returns null and the size
            // stays at 0, which is still greater than the limit.
            return new ArrayList<Weight>();
        }
        // Min-heap on Weight's natural ordering: evicting the head keeps the i largest.
        PriorityQueue<Weight> heaviest = new PriorityQueue<Weight>();
        for (Map.Entry<String, Vector> entry : this.weightMap.entrySet()) {
            heaviest.add(new Weight(entry.getKey(), entry.getValue()));
            // One element is added per iteration, so at most one eviction is needed.
            if (heaviest.size() > i) {
                heaviest.poll();
            }
        }
        List<Weight> ranked = new ArrayList<Weight>(heaviest);
        // Heap iteration order is unspecified, so sort explicitly, largest first.
        Collections.sort(ranked, Ordering.<Weight>natural().reverse());
        return ranked;
    }
}
