package org.apache.mahout.classifier.sequencelearning.hmm;

import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;

/* loaded from: input_file:org/apache/mahout/classifier/sequencelearning/hmm/HmmAlgorithms.class */
/**
 * Inference algorithms for discrete hidden Markov models: forward, backward and Viterbi.
 *
 * <p>Every algorithm can operate either on plain probabilities ({@code scaled == false}) or in
 * the natural-log domain ({@code scaled == true}). In log mode every value produced is a natural
 * logarithm and a zero probability is represented by {@link Double#NEGATIVE_INFINITY}; this
 * avoids the underflow plain probabilities suffer on long observation sequences.
 *
 * <p>Observations are column indices into the model's emission matrix, which is indexed as
 * (hidden state, observation). The transition matrix is indexed as (from state, to state).
 */
public final class HmmAlgorithms {

    private HmmAlgorithms() {
    }

    /**
     * Runs the forward algorithm and returns a freshly allocated result matrix.
     *
     * @param hmmModel the model to evaluate
     * @param iArr     the observed sequence (emission-matrix column indices)
     * @param z        if {@code true}, compute in the natural-log domain
     * @return an {@code iArr.length x nrOfHiddenStates} matrix whose entry (t, i) is the forward
     *     variable alpha_t(i) (its natural logarithm when {@code z} is set)
     */
    public static Matrix forwardAlgorithm(HmmModel hmmModel, int[] iArr, boolean z) {
        DenseMatrix alpha = new DenseMatrix(iArr.length, hmmModel.getNrOfHiddenStates());
        forwardAlgorithm(alpha, hmmModel, iArr, z);
        return alpha;
    }

    /**
     * Runs the forward algorithm, writing the forward variables into a caller-supplied matrix.
     *
     * <p>alpha_0(i) = pi(i) * b_i(o_0); alpha_t(i) = (sum_j alpha_{t-1}(j) * a_ji) * b_i(o_t).
     * In log mode the sum is accumulated with an overflow-safe log-sum-exp.
     *
     * @param matrix   destination with at least {@code iArr.length} rows and one column per
     *                 hidden state; row t receives alpha_t
     * @param hmmModel the model to evaluate
     * @param iArr     the observed sequence; if empty, {@code matrix} is left untouched
     * @param z        if {@code true}, compute in the natural-log domain
     */
    public static void forwardAlgorithm(Matrix matrix, HmmModel hmmModel, int[] iArr, boolean z) {
        if (iArr.length == 0) {
            return; // nothing to compute for an empty sequence
        }
        Vector initialProbabilities = hmmModel.getInitialProbabilities();
        Matrix emissionMatrix = hmmModel.getEmissionMatrix();
        Matrix transitionMatrix = hmmModel.getTransitionMatrix();
        int numStates = hmmModel.getNrOfHiddenStates();

        if (!z) {
            for (int i = 0; i < numStates; i++) {
                matrix.setQuick(0, i, initialProbabilities.getQuick(i) * emissionMatrix.getQuick(i, iArr[0]));
            }
            for (int t = 1; t < iArr.length; t++) {
                for (int i = 0; i < numStates; i++) {
                    double sum = 0.0;
                    for (int j = 0; j < numStates; j++) {
                        sum += matrix.getQuick(t - 1, j) * transitionMatrix.getQuick(j, i);
                    }
                    matrix.setQuick(t, i, sum * emissionMatrix.getQuick(i, iArr[t]));
                }
            }
            return;
        }

        for (int i = 0; i < numStates; i++) {
            // Add logs instead of taking log(pi * b): the product can underflow to 0 (=> -inf)
            // even when both logarithms are finite.
            matrix.setQuick(0, i, Math.log(initialProbabilities.getQuick(i))
                + Math.log(emissionMatrix.getQuick(i, iArr[0])));
        }
        for (int t = 1; t < iArr.length; t++) {
            for (int i = 0; i < numStates; i++) {
                double logSum = Double.NEGATIVE_INFINITY;
                for (int j = 0; j < numStates; j++) {
                    double term = matrix.getQuick(t - 1, j) + Math.log(transitionMatrix.getQuick(j, i));
                    if (term > Double.NEGATIVE_INFINITY) { // skip zero-probability paths (log 0)
                        logSum = logAdd(logSum, term);
                    }
                }
                matrix.setQuick(t, i, logSum + Math.log(emissionMatrix.getQuick(i, iArr[t])));
            }
        }
    }

    /**
     * Runs the backward algorithm and returns a freshly allocated result matrix.
     *
     * @param hmmModel the model to evaluate
     * @param iArr     the observed sequence (emission-matrix column indices)
     * @param z        if {@code true}, compute in the natural-log domain
     * @return an {@code iArr.length x nrOfHiddenStates} matrix whose entry (t, i) is the backward
     *     variable beta_t(i) (its natural logarithm when {@code z} is set)
     */
    public static Matrix backwardAlgorithm(HmmModel hmmModel, int[] iArr, boolean z) {
        DenseMatrix beta = new DenseMatrix(iArr.length, hmmModel.getNrOfHiddenStates());
        backwardAlgorithm(beta, hmmModel, iArr, z);
        return beta;
    }

    /**
     * Runs the backward algorithm, writing the backward variables into a caller-supplied matrix.
     *
     * <p>beta_{T-1}(i) = 1; beta_t(i) = sum_j beta_{t+1}(j) * a_ij * b_j(o_{t+1}).
     * In log mode the sum is accumulated with an overflow-safe log-sum-exp.
     *
     * @param matrix   destination with at least {@code iArr.length} rows and one column per
     *                 hidden state; row t receives beta_t
     * @param hmmModel the model to evaluate
     * @param iArr     the observed sequence; if empty, {@code matrix} is left untouched
     * @param z        if {@code true}, compute in the natural-log domain
     */
    public static void backwardAlgorithm(Matrix matrix, HmmModel hmmModel, int[] iArr, boolean z) {
        if (iArr.length == 0) {
            return; // nothing to compute for an empty sequence
        }
        Matrix emissionMatrix = hmmModel.getEmissionMatrix();
        Matrix transitionMatrix = hmmModel.getTransitionMatrix();
        int numStates = hmmModel.getNrOfHiddenStates();
        int last = iArr.length - 1;

        if (!z) {
            for (int i = 0; i < numStates; i++) {
                matrix.setQuick(last, i, 1.0);
            }
            for (int t = last - 1; t >= 0; t--) {
                for (int i = 0; i < numStates; i++) {
                    double sum = 0.0;
                    for (int j = 0; j < numStates; j++) {
                        sum += matrix.getQuick(t + 1, j) * transitionMatrix.getQuick(i, j)
                            * emissionMatrix.getQuick(j, iArr[t + 1]);
                    }
                    matrix.setQuick(t, i, sum);
                }
            }
            return;
        }

        for (int i = 0; i < numStates; i++) {
            matrix.setQuick(last, i, 0.0); // log(1)
        }
        for (int t = last - 1; t >= 0; t--) {
            for (int i = 0; i < numStates; i++) {
                double logSum = Double.NEGATIVE_INFINITY;
                for (int j = 0; j < numStates; j++) {
                    double term = matrix.getQuick(t + 1, j) + Math.log(transitionMatrix.getQuick(i, j))
                        + Math.log(emissionMatrix.getQuick(j, iArr[t + 1]));
                    if (term > Double.NEGATIVE_INFINITY) { // skip zero-probability paths (log 0)
                        logSum = logAdd(logSum, term);
                    }
                }
                matrix.setQuick(t, i, logSum);
            }
        }
    }

    /**
     * Computes the most likely hidden-state sequence for the given observations.
     *
     * @param hmmModel the model to decode with
     * @param iArr     the observed sequence (emission-matrix column indices)
     * @param z        if {@code true}, compute in the natural-log domain
     * @return the most likely hidden state for each observation; empty for an empty sequence
     */
    public static int[] viterbiAlgorithm(HmmModel hmmModel, int[] iArr, boolean z) {
        if (iArr.length == 0) {
            return new int[0]; // avoids allocating a backpointer table with -1 rows
        }
        int numStates = hmmModel.getNrOfHiddenStates();
        double[][] delta = new double[iArr.length][numStates];
        int[][] backPointers = new int[iArr.length - 1][numStates];
        int[] sequence = new int[iArr.length];
        viterbiAlgorithm(sequence, delta, backPointers, hmmModel, iArr, z);
        return sequence;
    }

    /**
     * Runs the Viterbi algorithm using caller-supplied work buffers.
     *
     * <p>delta[t][i] holds the score of the best path ending in state i at time t, and
     * iArr2[t-1][i] the predecessor state on that path. Ties are broken in favour of the
     * lowest state index. If every final score is 0 (or -inf in log mode), e.g. because of
     * underflow in unscaled mode, state 0 is used as the end state.
     *
     * @param iArr     output: receives the decoded state for each observation
     *                 (at least {@code iArr3.length} entries)
     * @param dArr     work buffer of at least {@code iArr3.length x nrOfHiddenStates}
     * @param iArr2    backpointer buffer of at least {@code (iArr3.length - 1) x nrOfHiddenStates}
     * @param hmmModel the model to decode with
     * @param iArr3    the observed sequence; if empty, no buffer is touched
     * @param z        if {@code true}, compute in the natural-log domain
     */
    public static void viterbiAlgorithm(int[] iArr, double[][] dArr, int[][] iArr2, HmmModel hmmModel,
                                        int[] iArr3, boolean z) {
        if (iArr3.length == 0) {
            return; // nothing to decode
        }
        Vector initialProbabilities = hmmModel.getInitialProbabilities();
        Matrix emissionMatrix = hmmModel.getEmissionMatrix();
        Matrix transitionMatrix = hmmModel.getTransitionMatrix();
        int numStates = hmmModel.getNrOfHiddenStates();
        int last = iArr3.length - 1;

        // Initialization: score of starting in each state and emitting the first observation.
        for (int i = 0; i < numStates; i++) {
            double pi = initialProbabilities.getQuick(i);
            double emission = emissionMatrix.getQuick(i, iArr3[0]);
            // In log mode add logs rather than taking log(pi * emission), which can underflow.
            dArr[0][i] = z ? Math.log(pi) + Math.log(emission) : pi * emission;
        }

        // Recursion: for each state keep only the best-scoring predecessor.
        for (int t = 1; t <= last; t++) {
            for (int i = 0; i < numStates; i++) {
                int bestPredecessor = 0;
                double bestScore;
                if (z) {
                    bestScore = dArr[t - 1][0] + Math.log(transitionMatrix.getQuick(0, i));
                    for (int j = 1; j < numStates; j++) {
                        double candidate = dArr[t - 1][j] + Math.log(transitionMatrix.getQuick(j, i));
                        if (candidate > bestScore) {
                            bestScore = candidate;
                            bestPredecessor = j;
                        }
                    }
                    dArr[t][i] = bestScore + Math.log(emissionMatrix.getQuick(i, iArr3[t]));
                } else {
                    bestScore = dArr[t - 1][0] * transitionMatrix.getQuick(0, i);
                    for (int j = 1; j < numStates; j++) {
                        double candidate = dArr[t - 1][j] * transitionMatrix.getQuick(j, i);
                        if (candidate > bestScore) {
                            bestScore = candidate;
                            bestPredecessor = j;
                        }
                    }
                    dArr[t][i] = bestScore * emissionMatrix.getQuick(i, iArr3[t]);
                }
                iArr2[t - 1][i] = bestPredecessor;
            }
        }

        // Termination: pick the best end state. Initialize explicitly so a reused output buffer
        // never leaks a stale end state when no score beats the initial threshold.
        iArr[last] = 0;
        double maxScore = z ? Double.NEGATIVE_INFINITY : 0.0;
        for (int i = 0; i < numStates; i++) {
            if (dArr[last][i] > maxScore) {
                maxScore = dArr[last][i];
                iArr[last] = i;
            }
        }

        // Backtracking along the stored predecessor pointers.
        for (int t = last - 1; t >= 0; t--) {
            iArr[t] = iArr2[t][iArr[t + 1]];
        }
    }

    /**
     * Returns log(exp(a) + exp(b)) without overflow. The larger operand is factored out, so
     * {@code Math.exp} is only ever evaluated on a non-positive argument.
     * {@link Double#NEGATIVE_INFINITY} (log 0) acts as the identity element.
     */
    private static double logAdd(double a, double b) {
        if (a == Double.NEGATIVE_INFINITY) {
            return b;
        }
        if (b == Double.NEGATIVE_INFINITY) {
            return a;
        }
        double max = Math.max(a, b);
        double min = Math.min(a, b);
        return max + Math.log1p(Math.exp(min - max));
    }
}
