package org.apache.mahout.classifier.sequencelearning.hmm;

import java.util.Arrays;
import java.util.List;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.junit.Test;

/**
 * Tests for the static helpers in {@link HmmUtils}: model validation, encoding and decoding of
 * state-name sequences, model normalization and model truncation.
 *
 * <p>The matrix/vector fixtures below are rebuilt by {@link #setUp()} before each test. The model
 * used by the encode/decode tests is obtained from the inherited {@code getModel()}.
 */
public class HMMUtilsTest extends HMMTestBase {

    /** Absolute tolerance used when comparing probabilities. */
    private static final double TOLERANCE = 1.0e-6;

    /** 2x2 matrix whose rows each sum to 1. */
    private Matrix legal22;
    /** 2x3 matrix whose rows each sum to 1. */
    private Matrix legal23;
    /** 3x3 matrix whose rows each sum to 1. */
    private Matrix legal33;
    /** Length-2 vector whose entries sum to 1. */
    private Vector legal2;
    /** 2x2 matrix with entries greater than 1, i.e. not a probability matrix. */
    private Matrix illegal22;

    @Override
    public void setUp() throws Exception {
        super.setUp();
        legal22 = new DenseMatrix(new double[][] {{0.5, 0.5}, {0.3, 0.7}});
        legal23 = new DenseMatrix(new double[][] {{0.2, 0.2, 0.6}, {0.3, 0.3, 0.4}});
        legal33 = new DenseMatrix(new double[][] {
            {0.1, 0.1, 0.8},
            {0.1, 0.2, 0.7},
            {0.2, 0.3, 0.5}});
        legal2 = new DenseVector(new double[] {0.4, 0.6});
        illegal22 = new DenseMatrix(new double[][] {{1.0, 2.0}, {3.0, 4.0}});
    }

    /** A model with 2 hidden states, 3 output states and stochastic rows must validate cleanly. */
    @Test
    public void testValidatorLegal() {
        HmmUtils.validate(new HmmModel(legal22, legal23, legal2));
    }

    /** A 3x3 transition matrix paired with 2-row emissions and a length-2 start vector is rejected. */
    @Test
    public void testValidatorDimensionError() {
        try {
            HmmUtils.validate(new HmmModel(legal33, legal23, legal2));
            fail("expected IllegalArgumentException for mismatched model dimensions");
        } catch (IllegalArgumentException expected) {
            // Expected: the dimensions of the model components disagree.
        }
    }

    /** A transition matrix whose entries are not probabilities is rejected. */
    @Test
    public void testValidatorIllegelMatrixError() {
        try {
            HmmUtils.validate(new HmmModel(illegal22, legal23, legal2));
            fail("expected IllegalArgumentException for a non-probability transition matrix");
        } catch (IllegalArgumentException expected) {
            // Expected: illegal22 has entries > 1 and rows that do not sum to 1.
        }
    }

    /**
     * Hidden and observed state names are encoded to their indices; names the model does not
     * map ("H4", "O4") receive the supplied default (-1).
     */
    @Test
    public void testEncodeStateSequence() {
        int[] hidden = HmmUtils.encodeStateSequence(
            getModel(), Arrays.asList("H1", "H2", "H0", "H3", "H4"), false, -1);
        int[] observed = HmmUtils.encodeStateSequence(
            getModel(), Arrays.asList("O1", "O2", "O4", "O0"), true, -1);
        assertSameSequence(new int[] {1, 2, 0, 3, -1}, hidden);
        assertSameSequence(new int[] {1, 2, -1, 0}, observed);
    }

    /**
     * Hidden and observed state indices are decoded to their names; an index with no name (10)
     * receives the supplied default ("unknown").
     */
    @Test
    public void testDecodeStateSequence() {
        List<String> hidden = HmmUtils.decodeStateSequence(
            getModel(), new int[] {1, 2, 0, 3, 10}, false, "unknown");
        List<String> observed = HmmUtils.decodeStateSequence(
            getModel(), new int[] {1, 2, 10, 0}, true, "unknown");
        assertSameSequence(new String[] {"H1", "H2", "H0", "H3", "unknown"}, hidden);
        assertSameSequence(new String[] {"O1", "O2", "unknown", "O0"}, observed);
    }

    /** Normalizing a model built from raw positive weights must yield a model that validates. */
    @Test
    public void testNormalizeModel() {
        HmmModel model = new HmmModel(
            new DenseMatrix(new double[][] {{10.0, 10.0}, {20.0, 25.0}}),
            new DenseMatrix(new double[][] {{5.0, 7.0}, {10.0, 15.0}}),
            new DenseVector(new double[] {10.0, 20.0}));
        HmmUtils.normalizeModel(model);
        // validate() is expected to throw IllegalArgumentException for an invalid model
        // (see the validator tests above), so reaching the end means normalization succeeded.
        HmmUtils.validate(model);
    }

    /**
     * Truncating a near-identity model with threshold 0.01 drops every 0.0001 entry; the
     * result is expected to contain only exact zeros and ones on the same support.
     */
    @Test
    public void testTruncateModel() {
        HmmModel original = new HmmModel(
            new DenseMatrix(nearIdentity3x3()),
            new DenseMatrix(nearIdentity3x3()),
            new DenseVector(new double[] {0.0001, 0.0001, 0.9998}));
        HmmModel truncated = HmmUtils.truncateModel(original, 0.01);
        HmmUtils.validate(truncated);

        Vector initial = truncated.getInitialProbabilities();
        Matrix transition = truncated.getTransitionMatrix();
        Matrix emission = truncated.getEmissionMatrix();
        int states = truncated.getNrOfHiddenStates();
        // Guard against the loops below passing vacuously on an empty model.
        assertEquals(3, states);

        for (int row = 0; row < states; row++) {
            assertEquals(row == 2 ? 1.0 : 0.0, initial.getQuick(row), TOLERANCE);
            for (int col = 0; col < states; col++) {
                double expected = row == col ? 1.0 : 0.0;
                assertEquals(expected, transition.getQuick(row, col), TOLERANCE);
                assertEquals(expected, emission.getQuick(row, col), TOLERANCE);
            }
        }
    }

    /** Returns a fresh 3x3 array with 0.9998 on the diagonal and 0.0001 elsewhere. */
    private static double[][] nearIdentity3x3() {
        return new double[][] {
            {0.9998, 0.0001, 0.0001},
            {0.0001, 0.9998, 0.0001},
            {0.0001, 0.0001, 0.9998}};
    }

    /** Asserts that {@code actual} has exactly the length and elements of {@code expected}. */
    private static void assertSameSequence(int[] expected, int[] actual) {
        assertEquals("sequence length", expected.length, actual.length);
        for (int i = 0; i < expected.length; i++) {
            assertEquals("element " + i, expected[i], actual[i]);
        }
    }

    /** Asserts that {@code actual} has exactly the size and elements of {@code expected}. */
    private static void assertSameSequence(String[] expected, List<String> actual) {
        assertEquals("sequence length", expected.length, actual.size());
        for (int i = 0; i < expected.length; i++) {
            assertEquals("element " + i, expected[i], actual.get(i));
        }
    }
}
