package hex;

import hex.ModelMetrics;
import hex.ModelMetricsSupervised;
import org.apache.lucene.util.packed.PackedInts;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Frame;
import water.util.ArrayUtils;
import water.util.ModelUtils;

/* loaded from: input_file:hex/ModelMetricsBinomial.class */
public class ModelMetricsBinomial extends ModelMetricsSupervised {
    /** AUC data supplied at construction; {@code null} when built with the two-argument constructor. */
    public final AUCData _aucdata;
    /** Confusion matrix obtained from {@code _aucdata.CM()}; {@code null} when {@code _aucdata} is. */
    public final ConfusionMatrix _cm;

    /**
     * Accumulates, row by row, the squared error of the probability assigned to the actual class
     * and one confusion matrix per probability threshold. These are turned into a
     * {@link ModelMetricsBinomial} by {@link #makeModelMetrics}.
     */
    public static class MetricBuilderBinomial extends ModelMetricsSupervised.MetricBuilderSupervised {
        /** Probability cut-offs; one confusion matrix is kept for each entry. */
        final float[] _thresholds;
        /** {@code _cms[t][actual][predicted]}: row counts for threshold {@code _thresholds[t]}. */
        long[][][] _cms;

        /**
         * @param domain     response domain, forwarded to the supervised builder
         * @param thresholds probability cut-offs; must be non-empty for two-class problems and
         *                   exactly one element otherwise (checked only when assertions are enabled)
         */
        public MetricBuilderBinomial(String[] domain, float[] thresholds) {
            super(domain);
            this._thresholds = thresholds;
            assert (_nclasses == 2 && thresholds.length > 0) || (_nclasses != 2 && thresholds.length == 1);
            this._cms = new long[thresholds.length][_nclasses][_nclasses];
        }

        /**
         * Folds one scored row into the running metrics. Rows whose actual value or
         * {@code ds[0]} is NaN are skipped.
         *
         * @param ds   prediction row; {@code ds[c + 1]} is read as the probability of class {@code c}
         *             ({@code ds[0]} is presumably the predicted label — only its NaN-ness is used here)
         * @param yact actual response; {@code yact[0]} is truncated to the actual class index
         * @param m    the scoring model (not used here)
         * @return {@code ds}, unchanged
         */
        @Override // hex.ModelMetricsSupervised.MetricBuilderSupervised, hex.ModelMetrics.MetricBuilder
        public float[] perRow(float[] ds, float[] yact, Model m) {
            if (Float.isNaN(yact[0]) || Float.isNaN(ds[0])) {
                return ds;
            }
            final int iact = (int) yact[0];

            // Sanity-check that ds[1..] forms a probability distribution.
            float probSum = 0.0f;
            for (int c = 1; c < ds.length; c++) {
                assert 0.0f <= ds[c] && ds[c] <= 1.0f;
                probSum += ds[c];
            }
            assert Math.abs(probSum - 1.0f) < 1.0E-6d;

            // Squared error of the probability given to the actual class.
            final float err = 1.0f - ds[iact + 1];
            _sumsqe += err * err;
            assert !Double.isNaN(_sumsqe);

            // Probability of class 1 decides the thresholded prediction. Iterate the builder's own
            // thresholds so the matrices line up with _cms and with what makeModelMetrics hands to AUC.
            final float pClass1 = ds[2];
            for (int t = 0; t < _thresholds.length; t++) {
                final int predicted = pClass1 >= _thresholds[t] ? 1 : 0;
                _cms[t][iact][predicted]++;
            }
            _count++;
            return ds;
        }

        /** Merges another builder's squared error, count and confusion matrices into this one. */
        @Override // hex.ModelMetrics.MetricBuilder
        public void reduce(ModelMetrics.MetricBuilder mb) {
            super.reduce(mb);
            ArrayUtils.add(_cms, ((MetricBuilderBinomial) mb)._cms);
        }

        /**
         * Builds the final metrics from the accumulated confusion matrices and squared error and
         * registers them with {@code m._output.addModelMetrics}.
         *
         * @param m     the model the metrics belong to
         * @param f     the frame that was scored
         * @param sigma value stored as the metrics' {@code _sigma}
         * @return whatever {@code addModelMetrics} returns for the new metrics
         */
        @Override // hex.ModelMetricsSupervised.MetricBuilderSupervised, hex.ModelMetrics.MetricBuilder
        public ModelMetrics makeModelMetrics(Model m, Frame f, double sigma) {
            final ConfusionMatrix[] cms = new ConfusionMatrix[_cms.length];
            for (int t = 0; t < cms.length; t++) {
                cms[t] = new ConfusionMatrix(_cms[t], _domain);
            }
            final AUCData aucdata = new AUC(cms, _thresholds, _domain).data();
            // NOTE(review): evaluates to NaN when no rows were counted (_count == 0).
            final double mse = _sumsqe / _count;
            return m._output.addModelMetrics(new ModelMetricsBinomial(m, f, aucdata, sigma, mse));
        }
    }

    /** Creates metrics with no AUC data and no confusion matrix. */
    public ModelMetricsBinomial(Model model, Frame frame) {
        super(model, frame);
        this._aucdata = null;
        this._cm = null;
    }

    /**
     * @param model   the model the metrics belong to
     * @param frame   the frame that was scored
     * @param aucdata AUC data; its {@code CM()} becomes {@link #_cm} (must be non-null)
     * @param sigma   value stored in {@code _sigma}
     * @param mse     value stored in {@code _mse}
     */
    public ModelMetricsBinomial(Model model, Frame frame, AUCData aucdata, double sigma, double mse) {
        super(model, frame);
        this._aucdata = aucdata;
        this._cm = aucdata.CM();
        this._sigma = sigma;
        this._mse = mse;
    }

    /** @return the confusion matrix, or {@code null} if none was supplied */
    @Override // hex.ModelMetrics
    public ConfusionMatrix cm() {
        return _cm;
    }

    /** @return the AUC data, or {@code null} if none was supplied */
    @Override // hex.ModelMetrics
    public AUCData auc() {
        return _aucdata;
    }

    /**
     * Looks up stored metrics for {@code model}/{@code frame} and requires them to be binomial.
     *
     * @throws H2OIllegalArgumentException if the lookup result is not a {@code ModelMetricsBinomial}
     *                                     (including when it is {@code null})
     */
    public static ModelMetricsBinomial getFromDKV(Model model, Frame frame) {
        final ModelMetrics mm = ModelMetrics.getFromDKV(model, frame);
        if (mm instanceof ModelMetricsBinomial) {
            return (ModelMetricsBinomial) mm;
        }
        // Guard against a null lookup result so the intended exception is thrown instead of an NPE.
        final String found = (mm == null) ? "null" : mm.getClass().toString();
        throw new H2OIllegalArgumentException(
            "Expected to find a Binomial ModelMetrics for model: " + model._key.toString() + " and frame: " + frame._key.toString(),
            "Expected to find a ModelMetricsBinomial for model: " + model._key.toString() + " and frame: " + frame._key.toString() + " but found a: " + found);
    }
}
