package org.apache.mahout.classifier.sequencelearning.hmm;

import com.google.common.io.Resources;
import java.io.IOException;
import java.net.URL;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;
import org.apache.commons.io.Charsets;
import org.apache.mahout.math.Matrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.class */
/**
 * Trains a hidden Markov model (HMM) part-of-speech tagger, reports its error rate on a test
 * corpus and tags a sample sentence.
 *
 * <p>Corpus files are read line by line. Each non-empty line is split on single spaces. The first
 * field is used as the word (an HMM output state) and the second as its POS tag (an HMM hidden
 * state). Blank lines separate sentences.
 *
 * <p>All state lives in static fields that {@link #main(String[])} populates in sequence, so the
 * class is not safe for concurrent use.
 */
public final class PosTagger {

    private static final Logger log = LoggerFactory.getLogger(PosTagger.class);

    /** Splits a corpus line into its space-separated fields. */
    private static final Pattern SPACE = Pattern.compile(" ");

    /** Splits a pre-processed free-text sentence on runs of spaces. */
    private static final Pattern SPACES = Pattern.compile("[ ]+");

    /** Punctuation characters that become separate tokens when tagging free text. */
    private static final Pattern PUNCTUATION = Pattern.compile("[,.!?:;\"]");

    /** A doubled single quote, which also becomes a separate token when tagging free text. */
    private static final Pattern DOUBLED_QUOTE = Pattern.compile("''");

    /** ID used for any word or tag that has no entry in {@link #wordIDs} / {@link #tagIDs}. */
    private static final int UNKNOWN_ID = 0;

    /** Pseudo count handed to the supervised trainer. */
    private static final double PSEUDO_COUNT = 0.05;

    /** Model built by {@link #trainModel(String)}; used by testing and tagging. */
    private static HmmModel taggingModel;

    /**
     * POS tag to hidden-state ID. IDs are assigned from 0 in order of first appearance.
     * NOTE(review): an unseen tag in test data maps to {@link #UNKNOWN_ID}, which is also the ID
     * of the first tag seen during training.
     */
    private static Map<String, Integer> tagIDs;

    /** Next hidden-state ID to assign; equals the number of distinct tags seen so far. */
    private static int nextTagId;

    /** Word to output-state ID. IDs are assigned from 1; ID 0 is reserved for unseen words. */
    private static Map<String, Integer> wordIDs;

    /** Next output-state ID to assign; one more than the number of distinct words seen so far. */
    private static int nextWordId = 1;

    /** Tag-ID sequences of the most recently read corpus, one array per sentence. */
    private static List<int[]> hiddenSequences;

    /** Word-ID sequences of the most recently read corpus, parallel to {@link #hiddenSequences}. */
    private static List<int[]> observedSequences;

    /** Number of non-empty lines in the most recently read corpus. */
    private static int readLines;

    private PosTagger() {
        // static utility class; not instantiable
    }

    /**
     * Reads a corpus into {@link #hiddenSequences} and {@link #observedSequences} (one array per
     * sentence) and counts its non-empty lines in {@link #readLines}.
     *
     * @param url        location of the corpus
     * @param buildModel if {@code true}, unseen words and tags are assigned fresh IDs; otherwise
     *                   they are recorded as {@link #UNKNOWN_ID}
     * @throws IOException if the corpus cannot be read, or if a non-empty line has fewer than two
     *                     space-separated fields
     */
    private static void readFromURL(String url, boolean buildModel) throws IOException {
        hiddenSequences = new ArrayList<int[]>();
        observedSequences = new ArrayList<int[]>();
        readLines = 0;

        List<Integer> hiddenSequence = new ArrayList<Integer>();
        List<Integer> observedSequence = new ArrayList<Integer>();
        int lineNumber = 0;
        for (String line : Resources.readLines(new URL(url), Charsets.UTF_8)) {
            lineNumber++;
            if (line.isEmpty()) {
                // A blank line ends the current sentence.
                flushSentence(hiddenSequence, observedSequence);
                continue;
            }
            readLines++;
            String[] fields = SPACE.split(line);
            if (fields.length < 2) {
                throw new IOException("Line " + lineNumber + " of " + url
                        + " has fewer than two space-separated fields: '" + line + "'");
            }
            String word = fields[0];
            String tag = fields[1];
            if (buildModel) {
                if (!wordIDs.containsKey(word)) {
                    wordIDs.put(word, nextWordId++);
                }
                if (!tagIDs.containsKey(tag)) {
                    tagIDs.put(tag, nextTagId++);
                }
            }
            observedSequence.add(idOrUnknown(wordIDs, word));
            hiddenSequence.add(idOrUnknown(tagIDs, tag));
        }
        // The corpus need not end with a blank line; register any trailing sentence as well.
        flushSentence(hiddenSequence, observedSequence);
    }

    /**
     * Appends the buffered sentence to the sequence lists and clears the buffers. Empty buffers
     * are ignored, so repeated or leading blank lines do not register sentences without tokens.
     */
    private static void flushSentence(List<Integer> hiddenSequence, List<Integer> observedSequence) {
        if (observedSequence.isEmpty()) {
            return;
        }
        hiddenSequences.add(toIntArray(hiddenSequence));
        observedSequences.add(toIntArray(observedSequence));
        hiddenSequence.clear();
        observedSequence.clear();
    }

    /** Copies a list of boxed integers into a primitive array, preserving order. */
    private static int[] toIntArray(List<Integer> values) {
        int[] result = new int[values.size()];
        int index = 0;
        for (Integer value : values) {
            result[index++] = value;
        }
        return result;
    }

    /** Returns the ID registered for {@code key}, or {@link #UNKNOWN_ID} if there is none. */
    private static int idOrUnknown(Map<String, Integer> ids, String key) {
        Integer id = ids.get(key);
        return id == null ? UNKNOWN_ID : id;
    }

    /** Returns the seconds elapsed since {@code startMillis} (a {@code currentTimeMillis} value). */
    private static double secondsSince(long startMillis) {
        return (System.currentTimeMillis() - startMillis) / 1000.0;
    }

    /**
     * Builds word/tag ID maps from the training corpus, trains {@link #taggingModel} on it and
     * registers the tag and word names with the model.
     *
     * @param url location of the training corpus
     * @throws IOException if the corpus cannot be read or is malformed
     */
    private static void trainModel(String url) throws IOException {
        // Initial capacities are presumably sized for the CoNLL-2000 training set.
        tagIDs = new HashMap<String, Integer>(44);
        wordIDs = new HashMap<String, Integer>(19122);
        // Restart ID assignment so the counters agree with the freshly created maps.
        nextTagId = 0;
        nextWordId = 1;

        log.info("Reading and parsing training data file from URL: {}", url);
        long start = System.currentTimeMillis();
        readFromURL(url, true);
        log.info("Parsing done in {} seconds!", secondsSince(start));
        log.info("Read {} lines containing {} sentences with a total of {} distinct words and {} distinct POS tags.",
                new Object[] {readLines, hiddenSequences.size(), wordIDs.size(), tagIDs.size()});

        start = System.currentTimeMillis();
        taggingModel = HmmTrainer.trainSupervisedSequence(nextTagId, nextWordId, hiddenSequences,
                observedSequences, PSEUDO_COUNT);

        // Overwrite the emission column of the unknown word: every tag gets 0.1/n, and the
        // "NNP" tag (presumably the Penn Treebank proper-noun tag) gets the larger 1/n.
        // The model is normalized afterwards.
        Matrix emissionMatrix = taggingModel.getEmissionMatrix();
        int nrOfHiddenStates = taggingModel.getNrOfHiddenStates();
        for (int state = 0; state < nrOfHiddenStates; state++) {
            emissionMatrix.setQuick(state, UNKNOWN_ID, 0.1 / nrOfHiddenStates);
        }
        Integer properNounId = tagIDs.get("NNP");
        if (properNounId == null) {
            log.warn("Training data contains no NNP tag; unknown words are not biased towards it.");
        } else {
            emissionMatrix.setQuick(properNounId, UNKNOWN_ID, 1.0 / nrOfHiddenStates);
        }
        HmmUtils.normalizeModel(taggingModel);

        taggingModel.registerHiddenStateNames(tagIDs);
        taggingModel.registerOutputStateNames(wordIDs);
        log.info("Trained HMM models in {} seconds!", secondsSince(start));
    }

    /**
     * Decodes every sentence of the test corpus with {@link #taggingModel} and logs the fraction
     * of tokens whose predicted tag differs from the corpus tag. The rate is NaN for a corpus
     * without tokens.
     *
     * @param url location of the test corpus
     * @throws IOException if the corpus cannot be read or is malformed
     */
    private static void testModel(String url) throws IOException {
        log.info("Reading and parsing test data file from URL: {}", url);
        long start = System.currentTimeMillis();
        readFromURL(url, false);
        log.info("Parsing done in {} seconds!", secondsSince(start));
        log.info("Read {} lines containing {} sentences.", readLines, hiddenSequences.size());

        start = System.currentTimeMillis();
        int errorCount = 0;
        int totalCount = 0;
        for (int sentence = 0; sentence < observedSequences.size(); sentence++) {
            int[] predicted = HmmEvaluator.decode(taggingModel, observedSequences.get(sentence), false);
            int[] expected = hiddenSequences.get(sentence);
            for (int token = 0; token < expected.length; token++) {
                totalCount++;
                if (predicted[token] != expected[token]) {
                    errorCount++;
                }
            }
        }
        log.info("POS tagged test file in {} seconds!", secondsSince(start));
        // Divide as doubles: integer division would truncate the rate to 0.
        double errorRate = (double) errorCount / totalCount;
        log.info("Tagged the test file with an error rate of: {}", errorRate);
    }

    /**
     * Splits free text into tokens. Punctuation characters and {@code ''} become separate tokens
     * and the text is split on runs of spaces. Surrounding whitespace is trimmed first so that no
     * empty leading token is produced. Blank input yields an empty list.
     */
    private static List<String> tokenize(String sentence) {
        String spaced = PUNCTUATION.matcher(sentence).replaceAll(" $0 ");
        spaced = DOUBLED_QUOTE.matcher(spaced).replaceAll(" '' ").trim();
        if (spaced.isEmpty()) {
            return new ArrayList<String>();
        }
        return Arrays.asList(SPACES.split(spaced));
    }

    /**
     * Tags a free-text sentence with {@link #taggingModel}.
     *
     * @param sentence the text to tag
     * @return one tag name per token produced by {@link #tokenize(String)}. The {@code null}
     *         passed last to the decoder is presumably the name used for unregistered states.
     */
    private static List<String> tagSentence(String sentence) {
        List<String> tokens = tokenize(sentence);
        if (tokens.isEmpty()) {
            return new ArrayList<String>();
        }
        return HmmUtils.decodeStateSequence(taggingModel,
                HmmEvaluator.decode(taggingModel,
                        HmmUtils.encodeStateSequence(taggingModel, tokens, true, UNKNOWN_ID), false),
                false, (String) null);
    }

    /**
     * Trains on the CoNLL-2000 training file, evaluates on the test file and logs the tags of a
     * sample sentence.
     *
     * @param args ignored
     * @throws IOException if a corpus cannot be read or is malformed
     */
    public static void main(String[] args) throws IOException {
        trainModel("http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/train.txt");
        testModel("http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/test.txt");
        String sentence = "McDonalds is a huge company with many employees .";
        // Use the same tokenization as tagSentence so that tokens and tags line up.
        List<String> tokens = tokenize(sentence);
        List<String> tags = tagSentence(sentence);
        for (int i = 0; i < tags.size(); i++) {
            log.info("{}[{}]", tokens.get(i), tags.get(i));
        }
    }
}
