package org.apache.mahout.classifier.df.tools;

import java.lang.reflect.Field;
import java.text.DecimalFormat;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.configuration.tree.DefaultExpressionEngine;
import org.apache.commons.lang.ArrayUtils;
import org.apache.mahout.classifier.df.data.Data;
import org.apache.mahout.classifier.df.data.Dataset;
import org.apache.mahout.classifier.df.data.Instance;
import org.apache.mahout.classifier.df.node.CategoricalNode;
import org.apache.mahout.classifier.df.node.Leaf;
import org.apache.mahout.classifier.df.node.Node;
import org.apache.mahout.classifier.df.node.NumericalNode;

/* loaded from: input_file:org/apache/mahout/classifier/df/tools/TreeVisualizer.class */
/**
 * Renders decision trees, and the path a single instance takes through a tree, as text.
 *
 * <p>The node classes and {@link Dataset} keep the needed state in non-public fields, so this
 * class reads them reflectively. The {@link Field} handles are looked up once per public call
 * and passed down the recursion in a map keyed {@code "<SimpleClassName>.<fieldName>"}.
 */
public final class TreeVisualizer {
    /** Text emitted once per depth level in front of every tree line. */
    private static final String INDENT = "|   ";

    private TreeVisualizer() {
    }

    /**
     * Formats a number with at most two fraction digits (pattern {@code 0.##}).
     *
     * <p>A new {@link DecimalFormat} is created on every call; the symbols used are those of the
     * JVM's default locale.
     */
    private static String doubleToString(double d) {
        return new DecimalFormat("0.##").format(d);
    }

    /**
     * Returns the display name of an attribute: its entry in {@code attrNames}, or the
     * attribute index itself when no names were supplied.
     */
    private static String attributeName(String[] attrNames, int attr) {
        return attrNames == null ? String.valueOf(attr) : attrNames[attr];
    }

    /** Starts a new output line and indents it by {@code depth} levels. */
    private static void startLine(StringBuilder sb, int depth) {
        sb.append('\n');
        for (int level = 0; level < depth; level++) {
            sb.append(INDENT);
        }
    }

    /**
     * Returns the text for a leaf's label: the formatted number when the dataset's label
     * attribute is numerical, otherwise the label string the dataset maps the value to.
     */
    private static String leafLabel(Leaf leaf, Dataset dataset, Map<String, Field> fields) throws IllegalAccessException {
        double label = ((Double) fields.get("Leaf.label").get(leaf)).doubleValue();
        if (dataset.isNumerical(dataset.getLabelId())) {
            return doubleToString(label);
        }
        return dataset.getLabelString((int) label);
    }

    /**
     * Recursively renders {@code node} and its subtree.
     *
     * <p>Every split condition goes on its own line, indented by {@code depth} levels; a leaf is
     * appended to the current line as {@code " : <label>"}. A node that is none of the three
     * known types contributes nothing.
     *
     * @param node      subtree root
     * @param dataset   dataset supplying categorical value names and label information
     * @param attrNames attribute display names, or {@code null} to print attribute indices
     * @param fields    reflective field handles from {@link #getReflectMap()}
     * @param depth     indentation level of this node's condition lines
     */
    private static String toStringNode(Node node, Dataset dataset, String[] attrNames, Map<String, Field> fields, int depth) throws IllegalAccessException {
        StringBuilder sb = new StringBuilder();
        if (node instanceof CategoricalNode) {
            int attr = ((Integer) fields.get("CategoricalNode.attr").get(node)).intValue();
            double[] splitValues = (double[]) fields.get("CategoricalNode.values").get(node);
            Node[] children = (Node[]) fields.get("CategoricalNode.childs").get(node);
            String[][] valueNames = (String[][]) fields.get("Dataset.values").get(dataset);
            for (int value = 0; value < children.length; value++) {
                startLine(sb, depth);
                sb.append(attributeName(attrNames, attr)).append(" = ").append(valueNames[attr][value]);
                // The child for this categorical value sits at the position where the value
                // appears in the node's values array; values without a child get no subtree.
                int childIndex = ArrayUtils.indexOf(splitValues, value);
                if (childIndex >= 0) {
                    sb.append(toStringNode(children[childIndex], dataset, attrNames, fields, depth + 1));
                }
            }
        } else if (node instanceof NumericalNode) {
            int attr = ((Integer) fields.get("NumericalNode.attr").get(node)).intValue();
            double split = ((Double) fields.get("NumericalNode.split").get(node)).doubleValue();
            Node loChild = (Node) fields.get("NumericalNode.loChild").get(node);
            Node hiChild = (Node) fields.get("NumericalNode.hiChild").get(node);

            startLine(sb, depth);
            sb.append(attributeName(attrNames, attr)).append(" < ").append(doubleToString(split));
            sb.append(toStringNode(loChild, dataset, attrNames, fields, depth + 1));

            startLine(sb, depth);
            sb.append(attributeName(attrNames, attr)).append(" >= ").append(doubleToString(split));
            sb.append(toStringNode(hiChild, dataset, attrNames, fields, depth + 1));
        } else if (node instanceof Leaf) {
            sb.append(" : ").append(leafLabel((Leaf) node, dataset, fields));
        }
        return sb.toString();
    }

    /**
     * Looks up the declared field {@code fieldName} of {@code owner}, makes it accessible and
     * stores it under the key {@code "<owner simple name>.<fieldName>"}.
     */
    private static void register(Map<String, Field> fields, Class<?> owner, String fieldName) throws NoSuchFieldException {
        Field field = owner.getDeclaredField(fieldName);
        field.setAccessible(true);
        fields.put(owner.getSimpleName() + '.' + fieldName, field);
    }

    /**
     * Builds the map of accessible {@link Field} handles used by the rendering methods.
     *
     * @throws NoSuchFieldException if one of the expected fields is not declared by its class
     */
    private static Map<String, Field> getReflectMap() throws NoSuchFieldException {
        Map<String, Field> fields = new HashMap<String, Field>();
        register(fields, CategoricalNode.class, "attr");
        register(fields, CategoricalNode.class, "values");
        register(fields, CategoricalNode.class, "childs");
        register(fields, NumericalNode.class, "attr");
        register(fields, NumericalNode.class, "split");
        register(fields, NumericalNode.class, "loChild");
        register(fields, NumericalNode.class, "hiChild");
        register(fields, Leaf.class, "label");
        register(fields, Dataset.class, "values");
        return fields;
    }

    /**
     * Renders the tree rooted at {@code node} as multi-line text.
     *
     * @param node    root of the tree
     * @param dataset dataset supplying categorical value names and label information
     * @param strArr  attribute display names, or {@code null} to print attribute indices
     * @return the rendered tree
     * @throws Exception if a required field cannot be found or read reflectively
     */
    public static String toString(Node node, Dataset dataset, String[] strArr) throws Exception {
        return toStringNode(node, dataset, strArr, getReflectMap(), 0);
    }

    /**
     * Prints the text produced by {@link #toString(Node, Dataset, String[])} to standard output.
     *
     * @throws Exception if a required field cannot be found or read reflectively
     */
    public static void print(Node node, Dataset dataset, String[] strArr) throws Exception {
        System.out.println(toString(node, dataset, strArr));
    }

    /**
     * Recursively renders the decisions taken for {@code instance}, joined by {@code " -> "} and
     * ending with the leaf's label.
     *
     * <p>At a categorical node whose values array does not contain the instance's value, nothing
     * is appended and the trace ends there.
     *
     * @param node      current node on the path
     * @param instance  instance whose attribute values select the branch
     * @param dataset   dataset supplying categorical value names and label information
     * @param attrNames attribute display names, or {@code null} to print attribute indices
     * @param fields    reflective field handles from {@link #getReflectMap()}
     */
    private static String toStringPredict(Node node, Instance instance, Dataset dataset, String[] attrNames, Map<String, Field> fields) throws IllegalAccessException {
        StringBuilder sb = new StringBuilder();
        if (node instanceof CategoricalNode) {
            int attr = ((Integer) fields.get("CategoricalNode.attr").get(node)).intValue();
            double[] splitValues = (double[]) fields.get("CategoricalNode.values").get(node);
            Node[] children = (Node[]) fields.get("CategoricalNode.childs").get(node);
            String[][] valueNames = (String[][]) fields.get("Dataset.values").get(dataset);
            double value = instance.get(attr);
            int childIndex = ArrayUtils.indexOf(splitValues, value);
            if (childIndex >= 0) {
                sb.append(attributeName(attrNames, attr)).append(" = ").append(valueNames[attr][(int) value]);
                sb.append(" -> ");
                sb.append(toStringPredict(children[childIndex], instance, dataset, attrNames, fields));
            }
        } else if (node instanceof NumericalNode) {
            int attr = ((Integer) fields.get("NumericalNode.attr").get(node)).intValue();
            double split = ((Double) fields.get("NumericalNode.split").get(node)).doubleValue();
            double value = instance.get(attr);
            boolean goesLow = value < split;
            // Values below the split follow loChild; everything else follows hiChild.
            Node next = (Node) fields.get(goesLow ? "NumericalNode.loChild" : "NumericalNode.hiChild").get(node);
            sb.append("(").append(attributeName(attrNames, attr)).append(" = ").append(doubleToString(value));
            sb.append(goesLow ? ") < " : ") >= ").append(doubleToString(split));
            sb.append(" -> ");
            sb.append(toStringPredict(next, instance, dataset, attrNames, fields));
        } else if (node instanceof Leaf) {
            sb.append(leafLabel((Leaf) node, dataset, fields));
        }
        return sb.toString();
    }

    /**
     * Produces one decision trace per instance of {@code data}.
     *
     * @param node   root of the tree
     * @param data   instances to trace; their dataset is taken from {@code data.getDataset()}
     * @param strArr attribute display names, or {@code null} to print attribute indices
     * @return an array with one trace per instance, in the order of {@code data}
     * @throws Exception if a required field cannot be found or read reflectively
     */
    public static String[] predictTrace(Node node, Data data, String[] strArr) throws Exception {
        Map<String, Field> fields = getReflectMap();
        int size = data.size();
        String[] traces = new String[size];
        for (int i = 0; i < size; i++) {
            traces[i] = toStringPredict(node, data.get(i), data.getDataset(), strArr, fields);
        }
        return traces;
    }

    /**
     * Prints one decision trace per instance of {@code data} to standard output, one per line.
     *
     * <p>Each trace is printed as soon as it is computed.
     *
     * @param node   root of the tree
     * @param data   instances to trace; their dataset is taken from {@code data.getDataset()}
     * @param strArr attribute display names, or {@code null} to print attribute indices
     * @throws Exception if a required field cannot be found or read reflectively
     */
    public static void predictTracePrint(Node node, Data data, String[] strArr) throws Exception {
        Map<String, Field> fields = getReflectMap();
        for (int i = 0; i < data.size(); i++) {
            System.out.println(toStringPredict(node, data.get(i), data.getDataset(), strArr, fields));
        }
    }
}
