package org.apache.mahout.classifier.sequencelearning.hmm;

import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.junit.Test;

/* loaded from: input_file:org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.class */
/**
 * Tests {@link HmmTrainer} by training the model supplied by {@code getModel()} on a fixed
 * observation sequence and comparing the trained parameters against reference values.
 *
 * <p>Each training algorithm (Viterbi and Baum-Welch) is exercised twice: once with the final
 * boolean argument {@code false} and once with {@code true} (the "scaled" variants). Both runs
 * are expected to reproduce the same reference parameters.
 */
public class HMMTrainerTest extends HMMTestBase {

    /** Observation sequence that every test trains the base model on. */
    private static final int[] OBSERVED = {1, 0, 2, 2, 0, 0, 1, 1, 1, 0, 2, 0, 1, 0, 0};

    /** Absolute tolerance used when comparing Viterbi-trained parameters. */
    private static final double VITERBI_EPSILON = 1.0e-6;

    /** Absolute tolerance used when comparing Baum-Welch-trained parameters. */
    private static final double BAUM_WELCH_EPSILON = 1.0e-4;

    /** Expected transition matrix after Viterbi training (rows/cols = hidden states). */
    private static final double[][] VITERBI_TRANSITION = {
        {0.3125, 0.0625, 0.3125, 0.3125},
        {0.25, 0.25, 0.25, 0.25},
        {0.5, 0.071429, 0.357143, 0.071429},
        {0.5, 0.1, 0.1, 0.3}
    };

    /** Expected emission matrix after Viterbi training (rows = hidden, cols = output states). */
    private static final double[][] VITERBI_EMISSION = {
        {0.882353, 0.058824, 0.058824},
        {0.333333, 0.333333, 0.3333333},
        {0.076923, 0.846154, 0.076923},
        {0.111111, 0.111111, 0.777778}
    };

    /** Expected initial state probabilities after Baum-Welch training. */
    private static final double[] BAUM_WELCH_INITIAL = {0.0, 0.0, 1.0, 0.0};

    /** Expected transition matrix after Baum-Welch training (rows/cols = hidden states). */
    private static final double[][] BAUM_WELCH_TRANSITION = {
        {0.2319, 0.0993, 0.0005, 0.6683},
        {0.0001, 0.3345, 0.6654, 0.0},
        {0.5975, 0.0, 0.4025, 0.0},
        {0.0024, 0.6657, 0.0, 0.3319}
    };

    /** Expected emission matrix after Baum-Welch training (rows = hidden, cols = output states). */
    private static final double[][] BAUM_WELCH_EMISSION = {
        {0.9995, 0.0004, 0.0001},
        {0.9943, 0.0036, 0.0021},
        {0.0059, 0.9941, 0.0},
        {0.0, 0.0, 1.0}
    };

    @Test
    public void testViterbiTraining() {
        checkViterbiTraining(false);
    }

    @Test
    public void testScaledViterbiTraining() {
        checkViterbiTraining(true);
    }

    @Test
    public void testBaumWelchTraining() {
        checkBaumWelchTraining(false);
    }

    @Test
    public void testScaledBaumWelchTraining() {
        checkBaumWelchTraining(true);
    }

    /**
     * Trains the base model with Viterbi training and checks the transition and emission matrices.
     *
     * @param scaled value passed as the final ({@code scaled}) argument of the trainer
     */
    private void checkViterbiTraining(boolean scaled) {
        // NOTE(review): 0.5 / 0.1 / 10 are presumably pseudo-count, convergence epsilon and
        // max iterations -- confirm against HmmTrainer.trainViterbi's signature.
        HmmModel trained = HmmTrainer.trainViterbi(getModel(), OBSERVED, 0.5, 0.1, 10, scaled);
        int hidden = trained.getNrOfHiddenStates();
        int outputs = trained.getNrOfOutputStates();
        assertMatrixEquals("transition", VITERBI_TRANSITION, trained.getTransitionMatrix(),
            hidden, hidden, VITERBI_EPSILON);
        assertMatrixEquals("emission", VITERBI_EMISSION, trained.getEmissionMatrix(),
            hidden, outputs, VITERBI_EPSILON);
    }

    /**
     * Trains the base model with Baum-Welch and checks the initial probabilities, transition
     * matrix and emission matrix.
     *
     * @param scaled value passed as the final ({@code scaled}) argument of the trainer
     */
    private void checkBaumWelchTraining(boolean scaled) {
        // NOTE(review): 0.1 / 10 are presumably convergence epsilon and max iterations --
        // confirm against HmmTrainer.trainBaumWelch's signature.
        HmmModel trained = HmmTrainer.trainBaumWelch(getModel(), OBSERVED, 0.1, 10, scaled);
        int hidden = trained.getNrOfHiddenStates();
        int outputs = trained.getNrOfOutputStates();
        assertVectorEquals("initial", BAUM_WELCH_INITIAL, trained.getInitialProbabilities(),
            hidden, BAUM_WELCH_EPSILON);
        assertMatrixEquals("transition", BAUM_WELCH_TRANSITION, trained.getTransitionMatrix(),
            hidden, hidden, BAUM_WELCH_EPSILON);
        assertMatrixEquals("emission", BAUM_WELCH_EMISSION, trained.getEmissionMatrix(),
            hidden, outputs, BAUM_WELCH_EPSILON);
    }

    /**
     * Asserts that the first {@code rows x cols} entries of {@code actual} match {@code expected}.
     *
     * <p>The expected table's dimensions are checked first. A model whose state counts differ
     * from the reference then fails with a clear message, instead of an array index error or
     * silently unchecked entries.
     *
     * @param label   name used in failure messages
     * @param expected reference values, indexed {@code [row][col]}
     * @param actual  trained matrix to check
     * @param rows    number of rows the model reports
     * @param cols    number of columns the model reports
     * @param epsilon absolute tolerance per entry
     */
    private void assertMatrixEquals(String label, double[][] expected, Matrix actual,
                                    int rows, int cols, double epsilon) {
        assertEquals(label + " row count", expected.length, rows);
        for (int row = 0; row < rows; row++) {
            assertEquals(label + " column count in row " + row, expected[row].length, cols);
            for (int col = 0; col < cols; col++) {
                // JUnit convention: expected value first, actual value second.
                assertEquals(label + '[' + row + "][" + col + ']',
                    expected[row][col], actual.getQuick(row, col), epsilon);
            }
        }
    }

    /**
     * Asserts that the first {@code size} entries of {@code actual} match {@code expected}.
     *
     * @param label   name used in failure messages
     * @param expected reference values
     * @param actual  trained vector to check
     * @param size    number of entries the model reports
     * @param epsilon absolute tolerance per entry
     */
    private void assertVectorEquals(String label, double[] expected, Vector actual,
                                    int size, double epsilon) {
        assertEquals(label + " size", expected.length, size);
        for (int i = 0; i < size; i++) {
            assertEquals(label + '[' + i + ']', expected[i], actual.get(i), epsilon);
        }
    }
}
