package org.apache.mahout.classifier.df.node;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.mahout.classifier.df.DFUtils;
import org.apache.mahout.classifier.df.data.Instance;
import org.apache.mahout.classifier.df.node.Node;

/* loaded from: input_file:org/apache/mahout/classifier/df/node/CategoricalNode.class */
/**
 * Decision-tree node that branches on a categorical attribute. The value at index {@code k} of
 * {@code values} is paired with the child at index {@code k} of {@code childs}.
 */
public class CategoricalNode extends Node {
    /** Index of the attribute this node tests. */
    private int attr;
    /** Attribute values this node can route; {@code values[k]} routes to {@code childs[k]}. */
    private double[] values;
    /** Child nodes, parallel to {@code values}. */
    private Node[] childs;

    /**
     * Creates an uninitialized node. The fields stay unset (arrays are null) until populated,
     * e.g. by {@link #readFields(DataInput)}.
     */
    public CategoricalNode() {
    }

    /**
     * Creates a node testing attribute {@code attr}. The arrays are stored as given (not copied).
     *
     * @param attr   index of the attribute to test
     * @param values attribute values, one per child
     * @param childs child nodes; {@code childs[k]} handles instances whose attribute equals
     *               {@code values[k]}
     */
    public CategoricalNode(int attr, double[] values, Node[] childs) {
        this.attr = attr;
        this.values = values;
        this.childs = childs;
    }

    /**
     * Routes {@code instance} to the child matching its value of attribute {@code attr} and
     * returns that child's classification.
     *
     * @param instance the instance to classify
     * @return the child's result, or {@link Double#NaN} if the instance's value is not in
     *         {@code values}
     */
    @Override
    public double classify(Instance instance) {
        int index = ArrayUtils.indexOf(values, instance.get(attr));
        if (index == ArrayUtils.INDEX_NOT_FOUND) {
            // No child is associated with this attribute value.
            return Double.NaN;
        }
        return childs[index].classify(instance);
    }

    /**
     * Returns 1 plus the greatest {@code maxDepth()} among the children (1 when there are no
     * children).
     */
    @Override
    public long maxDepth() {
        long deepest = 0;
        for (Node child : childs) {
            deepest = Math.max(deepest, child.maxDepth());
        }
        return 1 + deepest;
    }

    /** Returns the number of nodes in this subtree: this node plus every child's {@code nbNodes()}. */
    @Override
    public long nbNodes() {
        long count = 1;
        for (Node child : childs) {
            count += child.nbNodes();
        }
        return count;
    }

    @Override
    protected Node.Type getType() {
        return Node.Type.CATEGORICAL;
    }

    /**
     * Two categorical nodes are equal when they test the same attribute and have element-wise
     * equal {@code values} and {@code childs} arrays.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof CategoricalNode)) {
            return false;
        }
        CategoricalNode other = (CategoricalNode) obj;
        return attr == other.attr
                && Arrays.equals(values, other.values)
                && Arrays.equals(childs, other.childs);
    }

    /**
     * Hash code consistent with {@link #equals(Object)}.
     *
     * <p>{@link Arrays#hashCode(double[])} folds the high and low 32 bits of each value's bit
     * pattern. Keeping only the low 32 bits would map small integral doubles such as 1.0, 2.0
     * and 3.0 to 0, so the category values would not affect the hash. {@code Arrays.hashCode}
     * also returns 0 for a null array, so an uninitialized node does not throw here.
     */
    @Override
    public int hashCode() {
        int hash = attr;
        hash = 31 * hash + Arrays.hashCode(values);
        hash = 31 * hash + Arrays.hashCode(childs);
        return hash;
    }

    /** Returns each child's string form followed by a comma, concatenated in child order. */
    @Override
    protected String getString() {
        StringBuilder sb = new StringBuilder();
        for (Node child : childs) {
            sb.append(child).append(',');
        }
        return sb.toString();
    }

    /** Reads {@code attr}, {@code values} and {@code childs}, in the order written by {@link #writeNode}. */
    @Override // org.apache.hadoop.io.Writable
    public void readFields(DataInput dataInput) throws IOException {
        attr = dataInput.readInt();
        values = DFUtils.readDoubleArray(dataInput);
        childs = DFUtils.readNodeArray(dataInput);
    }

    /** Writes {@code attr}, {@code values} and {@code childs}, in the order expected by {@link #readFields}. */
    @Override
    protected void writeNode(DataOutput dataOutput) throws IOException {
        dataOutput.writeInt(attr);
        DFUtils.writeArray(dataOutput, values);
        DFUtils.writeArray(dataOutput, childs);
    }
}
