package org.apache.mahout.classifier.df.split;

import java.util.Arrays;
import java.util.Iterator;
import java.util.TreeSet;
import org.apache.commons.math3.stat.descriptive.rank.Percentile;
import org.apache.mahout.classifier.df.data.Data;
import org.apache.mahout.classifier.df.data.DataUtils;
import org.apache.mahout.classifier.df.data.Dataset;
import org.apache.mahout.classifier.df.data.Instance;

/**
 * Information-gain split that bounds the work done per attribute.
 *
 * <p>Numerical attributes are scored at a limited set of candidate thresholds (at most
 * {@link #MAX_NUMERIC_SPLITS}: midpoints between consecutive sorted values for small inputs,
 * evenly spaced percentiles otherwise), and the best-scoring threshold is returned. Categorical
 * attributes are scored by the information gain of partitioning on every distinct value.
 */
public class OptIgSplit extends IgSplit {

    /** Upper bound on the number of candidate thresholds evaluated for a numerical attribute. */
    private static final int MAX_NUMERIC_SPLITS = 16;

    /**
     * Computes the split of {@code data} on attribute {@code attr}, dispatching to the numerical or
     * categorical strategy according to the dataset's declaration of that attribute.
     *
     * @param data instances to split
     * @param attr index of the attribute to split on
     * @return the split found for {@code attr}
     */
    @Override
    public Split computeSplit(Data data, int attr) {
        return data.getDataset().isNumerical(attr)
                ? numericalSplit(data, attr)
                : categoricalSplit(data, attr);
    }

    /**
     * Scores a categorical split: the information gain {@code H(Y) - H(Y|X)} obtained by
     * partitioning {@code data} into one branch per distinct value of {@code attr}.
     */
    private static Split categoricalSplit(Data data, int attr) {
        double[] splitPoints = chooseCategoricalSplitPoints(data.values(attr).clone());

        int numLabels = data.getDataset().nblabels();
        int[][] counts = new int[splitPoints.length][numLabels];
        int[] countAll = new int[numLabels];
        computeFrequencies(data, attr, splitPoints, counts, countAll);

        int size = data.size();
        double hy = entropy(countAll, size); // H(Y)
        double invDataSize = 1.0 / size;
        double hyx = 0.0; // H(Y|X): size-weighted entropy of each branch
        for (int[] branchCounts : counts) {
            int branchSize = DataUtils.sum(branchCounts);
            hyx += branchSize * invDataSize * entropy(branchCounts, branchSize);
        }
        return new Split(attr, hy - hyx);
    }

    /**
     * Tallies label frequencies for attribute {@code attr}, bucketed by {@code splitPoints}.
     *
     * <p>Each instance is assigned to the first bucket {@code k} for which
     * {@code value > splitPoints[k]} is false (i.e. {@code value <= splitPoints[k]} for non-NaN
     * values); instances greater than every split point are not recorded in {@code counts}.
     * Every instance is recorded in {@code countAll}. Both arrays are incremented, not reset.
     *
     * @param data        instances to count
     * @param attr        attribute whose value selects the bucket
     * @param splitPoints bucket upper bounds, scanned in order
     * @param counts      {@code counts[k][label]} is incremented for instances in bucket {@code k}
     * @param countAll    {@code countAll[label]} is incremented for every instance
     */
    static void computeFrequencies(Data data, int attr, double[] splitPoints,
                                   int[][] counts, int[] countAll) {
        Dataset dataset = data.getDataset();
        for (int index = 0; index < data.size(); index++) {
            Instance instance = data.get(index);
            int label = (int) dataset.getLabel(instance);
            double value = instance.get(attr);

            int split = 0;
            while (split < splitPoints.length && value > splitPoints[split]) {
                split++;
            }
            if (split < splitPoints.length) {
                counts[split][label]++;
            } // otherwise the value lies past the last split point; only countAll records it
            countAll[label]++;
        }
    }

    /**
     * Finds the candidate threshold on numerical attribute {@code attr} with the highest
     * information gain.
     *
     * @throws IllegalStateException if no candidate threshold yields a gain (e.g. empty data)
     */
    static Split numericalSplit(Data data, int attr) {
        double[] values = data.values(attr).clone();
        Arrays.sort(values);
        double[] splitPoints = chooseNumericSplitPoints(values);

        int numLabels = data.getDataset().nblabels();
        int[][] counts = new int[splitPoints.length][numLabels];
        int[] countAll = new int[numLabels];
        int[] countLess = new int[numLabels];
        computeFrequencies(data, attr, splitPoints, counts, countAll);

        int size = data.size();
        double hy = entropy(countAll, size); // H(Y)
        double invDataSize = 1.0 / size;

        int best = -1;
        double bestIg = -1.0;
        for (int index = 0; index < splitPoints.length; index++) {
            // Sweep the threshold upward: bucket `index` moves from the "above" side (countAll)
            // to the "at or below" side (countLess).
            DataUtils.add(countLess, counts[index]);
            DataUtils.dec(countAll, counts[index]);

            int lessSize = DataUtils.sum(countLess);
            int greaterSize = DataUtils.sum(countAll);
            double ig = hy
                    - lessSize * invDataSize * entropy(countLess, lessSize)
                    - greaterSize * invDataSize * entropy(countAll, greaterSize);

            if (ig > bestIg) {
                bestIg = ig;
                best = index;
            }
        }

        if (best == -1) {
            throw new IllegalStateException("no best split found !");
        }
        return new Split(attr, bestIg, splitPoints[best]);
    }

    /**
     * Chooses candidate thresholds for a numerical attribute.
     *
     * @param values the attribute's values, sorted ascending
     * @return {@code values} itself if it has at most one element; the midpoints of consecutive
     *     values if there are at most {@code MAX_NUMERIC_SPLITS + 1}; otherwise
     *     {@code MAX_NUMERIC_SPLITS} evenly spaced percentiles of {@code values}
     */
    private static double[] chooseNumericSplitPoints(double[] values) {
        if (values.length <= 1) {
            return values;
        }
        if (values.length <= MAX_NUMERIC_SPLITS + 1) {
            double[] splitPoints = new double[values.length - 1];
            for (int i = 1; i < values.length; i++) {
                splitPoints[i - 1] = (values[i] + values[i - 1]) / 2.0;
            }
            return splitPoints;
        }
        Percentile distribution = new Percentile();
        distribution.setData(values);
        double[] percentiles = new double[MAX_NUMERIC_SPLITS];
        for (int i = 0; i < percentiles.length; i++) {
            // quantiles at 1/(n+1), 2/(n+1), ..., n/(n+1), expressed in percent
            double p = 100.0 * ((i + 1.0) / (MAX_NUMERIC_SPLITS + 1.0));
            percentiles[i] = distribution.evaluate(p);
        }
        return percentiles;
    }

    /**
     * Returns the distinct values of a categorical attribute in ascending
     * ({@link Double#compareTo}) order.
     */
    private static double[] chooseCategoricalSplitPoints(double[] values) {
        TreeSet<Double> distinct = new TreeSet<Double>();
        for (double value : values) {
            distinct.add(value);
        }
        double[] splitPoints = new double[distinct.size()];
        int index = 0;
        for (double value : distinct) {
            splitPoints[index++] = value;
        }
        return splitPoints;
    }

    /**
     * Computes the entropy of a label distribution.
     *
     * <p>The natural-log entropy is divided by the inherited {@code LOG2} constant (presumably
     * {@code Math.log(2)}, giving entropy in bits — confirm in {@code IgSplit}).
     *
     * @param counts   per-label instance counts
     * @param dataSize total number of instances represented by {@code counts}
     * @return the entropy, or 0 when {@code dataSize} is 0
     */
    private static double entropy(int[] counts, int dataSize) {
        if (dataSize == 0) {
            return 0.0;
        }
        double entropy = 0.0;
        for (int count : counts) {
            if (count > 0) {
                // Must be floating-point division: integer division truncates p to 0 for any
                // count < dataSize, and 0 * log(0) evaluates to NaN.
                double p = (double) count / dataSize;
                entropy -= p * Math.log(p);
            }
        }
        return entropy / LOG2;
    }
}
