package org.apache.hadoop.hive.ql.udf.generic;

import java.util.ArrayList;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.Text;

@Description(name = "ngrams", value = "_FUNC_(expr, n, k, pf) - Estimates the top-k n-grams in rows that consist of sequences of strings, represented as arrays of strings, or arrays of arrays of strings. 'pf' is an optional precision factor that controls memory usage.", extended = "The parameter 'n' specifies what type of n-grams are being estimated. Unigrams are n = 1, and bigrams are n = 2. Generally, n will not be greater than about 5. The 'k' parameter specifies how many of the highest-frequency n-grams will be returned by the UDAF. The optional precision factor 'pf' specifies how much memory to use for estimation; more memory will give more accurate frequency counts, but could crash the JVM. The default value is 20, which internally maintains 20*k n-grams, but only returns the k highest frequency ones. The output is an array of structs with the top-k n-grams. It might be convenient to explode() the output of this UDAF.")
/* loaded from: input_file:WEB-INF/lib/hive-exec-1.2.0-mapr-1710-r3.jar:org/apache/hadoop/hive/ql/udf/generic/GenericUDAFnGrams.class */
public class GenericUDAFnGrams implements GenericUDAFResolver {
    static final Log LOG = LogFactory.getLog(GenericUDAFnGrams.class.getName());

    /* loaded from: input_file:WEB-INF/lib/hive-exec-1.2.0-mapr-1710-r3.jar:org/apache/hadoop/hive/ql/udf/generic/GenericUDAFnGrams$GenericUDAFnGramEvaluator.class */
    public static class GenericUDAFnGramEvaluator extends GenericUDAFEvaluator {
        // Inspector for the first argument (the outer array); set by init() in PARTIAL1 and COMPLETE modes.
        private transient ListObjectInspector outerInputOI;
        // Inspector for the nested arrays when the first argument is an array of arrays; null otherwise.
        private transient StandardListObjectInspector innerInputOI;
        // Inspector for the individual elements of the innermost array.
        private transient PrimitiveObjectInspector inputOI;
        // Inspector for the second argument ('n').
        private transient PrimitiveObjectInspector nOI;
        // Inspector for the third argument ('k').
        private transient PrimitiveObjectInspector kOI;
        // Inspector for the optional fourth argument ('pf'); null when only three arguments are supplied.
        private transient PrimitiveObjectInspector pOI;
        // Inspector for the partial aggregation read by merge(); set by init() in modes other than PARTIAL1 and COMPLETE.
        private transient ListObjectInspector loi;
        // Decompiler rendering of the compiler-generated assertion flag; assigned in the static initializer of this class.
        static final /* synthetic */ boolean $assertionsDisabled;

        /* JADX INFO: Access modifiers changed from: package-private */
        /* loaded from: input_file:WEB-INF/lib/hive-exec-1.2.0-mapr-1710-r3.jar:org/apache/hadoop/hive/ql/udf/generic/GenericUDAFnGrams$GenericUDAFnGramEvaluator$NGramAggBuf.class */
        /**
         * Aggregation state for one group: the n-gram estimator plus the n-gram length in use.
         */
        public static class NGramAggBuf extends GenericUDAFEvaluator.AbstractAggregationBuffer {
            // Estimator that accumulates the n-grams; assigned by getNewAggregationBuffer().
            NGramEstimator nge;
            // Value of 'n' for this buffer; 0 after reset() until iterate() or merge() assigns it.
            int n;

            NGramAggBuf() {
            }
        }

        /**
         * Captures the object inspectors needed for the given aggregation mode and returns
         * the inspector describing this mode's output.
         *
         * <p>In PARTIAL1 and COMPLETE the evaluator reads original rows, so the inspectors for
         * the array argument and for 'n', 'k' and the optional 'pf' are stored. In the other
         * modes the single input is a partial aggregation, read through a list inspector.
         *
         * @param mode the aggregation stage this evaluator instance runs in
         * @param objectInspectorArr inspectors for the original arguments (PARTIAL1/COMPLETE)
         *        or for the partial aggregation (other modes)
         * @return a list-of-strings inspector for PARTIAL1/PARTIAL2; otherwise an inspector for
         *         a list of structs with the fields "ngram" and "estfrequency"
         * @throws HiveException propagated from {@code super.init}
         */
        @Override // org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator
        public ObjectInspector init(GenericUDAFEvaluator.Mode mode, ObjectInspector[] objectInspectorArr) throws HiveException {
            super.init(mode, objectInspectorArr);

            boolean readsOriginalRows = mode == GenericUDAFEvaluator.Mode.PARTIAL1
                    || mode == GenericUDAFEvaluator.Mode.COMPLETE;
            if (readsOriginalRows) {
                this.outerInputOI = (ListObjectInspector) objectInspectorArr[0];
                ObjectInspector elementOI = this.outerInputOI.getListElementObjectInspector();
                if (elementOI.getCategory() == ObjectInspector.Category.LIST) {
                    // array<array<string>>: keep the nested list inspector and read strings one level down.
                    this.innerInputOI = (StandardListObjectInspector) elementOI;
                    this.inputOI = (PrimitiveObjectInspector) this.innerInputOI.getListElementObjectInspector();
                } else {
                    // array<string>: the elements themselves are the strings.
                    this.inputOI = (PrimitiveObjectInspector) elementOI;
                    this.innerInputOI = null;
                }
                this.nOI = (PrimitiveObjectInspector) objectInspectorArr[1];
                this.kOI = (PrimitiveObjectInspector) objectInspectorArr[2];
                this.pOI = objectInspectorArr.length == 4 ? (PrimitiveObjectInspector) objectInspectorArr[3] : null;
            } else {
                this.loi = (ListObjectInspector) objectInspectorArr[0];
            }

            if (mode == GenericUDAFEvaluator.Mode.PARTIAL1 || mode == GenericUDAFEvaluator.Mode.PARTIAL2) {
                // Partial results travel as a flat list of strings (see terminatePartial()).
                return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
            }

            // Final result: array<struct<ngram:array<string>, estfrequency:double>>.
            List<String> fieldNames = new ArrayList<String>();
            fieldNames.add("ngram");
            fieldNames.add("estfrequency");
            List<ObjectInspector> fieldInspectors = new ArrayList<ObjectInspector>();
            fieldInspectors.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableStringObjectInspector));
            fieldInspectors.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
            return ObjectInspectorFactory.getStandardListObjectInspector(
                    ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldInspectors));
        }

        /**
         * Folds a partial aggregation, as produced by {@link #terminatePartial}, into the buffer.
         *
         * <p>The last entry of the partial list carries the sender's 'n'; the entries before it
         * are handed to the estimator. The list obtained from the object inspector is never
         * modified: it may be a fixed-size view or storage owned by the row object, so the
         * estimator receives a copy without the trailing entry.
         *
         * @param aggregationBuffer the {@code NGramAggBuf} to merge into
         * @param obj the partial aggregation, or null for nothing to merge
         * @throws HiveException if the partial's 'n' differs from the 'n' already held by the buffer
         */
        @Override // org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator
        public void merge(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer, Object obj) throws HiveException {
            if (obj == null) {
                return;
            }
            NGramAggBuf target = (NGramAggBuf) aggregationBuffer;
            List<?> partial = this.loi.getList(obj);
            int lastIndex = partial.size() - 1;
            int partialN = Integer.parseInt(partial.get(lastIndex).toString());
            if (partialN == 0) {
                // reset() leaves n at 0 and only iterate()/merge() change it, so a trailing 0
                // means the sending buffer never saw usable input.
                return;
            }
            if (target.n > 0 && target.n != partialN) {
                throw new HiveException(getClass().getSimpleName() + ": mismatch in value for 'n', which usually is caused by a non-constant expression. Found '" + partialN + "' and '" + target.n + "'.");
            }
            target.n = partialN;
            // Copy everything except the trailing 'n' instead of removing it from the inspector's list.
            List<?> estimatorState = new ArrayList<Object>(partial.subList(0, lastIndex));
            target.nge.merge(estimatorState);
        }

        /**
         * Builds the partial aggregation consumed by merge(): the estimator's serialized state
         * followed by one trailing entry that holds the buffer's 'n' as text.
         *
         * @param aggregationBuffer the {@code NGramAggBuf} to serialize
         * @return the list returned by the estimator's serialize(), with 'n' appended
         */
        @Override // org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator
        public Object terminatePartial(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer) throws HiveException {
            NGramAggBuf buffer = (NGramAggBuf) aggregationBuffer;
            Text trailingN = new Text(String.valueOf(buffer.n));
            ArrayList<Text> partial = buffer.nge.serialize();
            partial.add(trailingN);
            return partial;
        }

        /**
         * Hands every window of n consecutive words from one sequence to the estimator,
         * beginning with the last window and moving toward the front. Sequences shorter
         * than n contribute nothing.
         *
         * @param nGramAggBuf buffer supplying 'n' and the estimator
         * @param arrayList the words of one sequence, in order
         */
        private void processNgrams(NGramAggBuf nGramAggBuf, ArrayList<String> arrayList) throws HiveException {
            int width = nGramAggBuf.n;
            int start = arrayList.size() - width;
            while (start >= 0) {
                // Each n-gram is an independent copy, not a view onto the sequence.
                ArrayList<String> ngram = new ArrayList<>(arrayList.subList(start, start + width));
                nGramAggBuf.nge.add(ngram);
                start--;
            }
        }

        /**
         * Processes one input row: initializes the estimator on first use, then adds the
         * n-grams of the row's word sequence(s).
         *
         * <p>Rows whose first three arguments contain a null are ignored. When the first
         * argument is an array of arrays, each nested array is treated as its own sequence;
         * a nested array that reads as null is skipped rather than dereferenced.
         *
         * @param aggregationBuffer the {@code NGramAggBuf} for the current group
         * @param objArr the row's arguments: sequence(s), 'n', 'k' and optionally 'pf'
         * @throws HiveException if 'n', 'k' or 'pf' is smaller than 1
         */
        @Override // org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator
        public void iterate(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer, Object[] objArr) throws HiveException {
            // Decompiled form of an assert on the argument count; kept as-is because the
            // class declares the $assertionsDisabled flag explicitly.
            if (!$assertionsDisabled && objArr.length != 3 && objArr.length != 4) {
                throw new AssertionError();
            }
            if (objArr[0] == null || objArr[1] == null || objArr[2] == null) {
                return;
            }
            NGramAggBuf buffer = (NGramAggBuf) aggregationBuffer;
            if (!buffer.nge.isInitialized()) {
                initializeEstimator(buffer, objArr);
            }

            List<?> outer = this.outerInputOI.getList(objArr[0]);
            if (outer == null) {
                return;
            }
            if (this.innerInputOI == null) {
                // array<string>: the row is a single sequence.
                processNgrams(buffer, toStringList(outer));
                return;
            }
            // array<array<string>>: every nested array is a separate sequence.
            for (int i = 0; i < outer.size(); i++) {
                List<?> sequence = this.innerInputOI.getList(outer.get(i));
                if (sequence == null) {
                    continue;
                }
                processNgrams(buffer, toStringList(sequence));
            }
        }

        /** Validates 'n', 'k' and the optional 'pf' from the row and initializes the buffer's estimator with them. */
        private void initializeEstimator(NGramAggBuf buffer, Object[] objArr) throws HiveException {
            int n = PrimitiveObjectInspectorUtils.getInt(objArr[1], this.nOI);
            int k = PrimitiveObjectInspectorUtils.getInt(objArr[2], this.kOI);
            if (n < 1) {
                throw new HiveException(getClass().getSimpleName() + " needs 'n' to be at least 1, but you supplied " + n);
            }
            if (k < 1) {
                throw new HiveException(getClass().getSimpleName() + " needs 'k' to be at least 1, but you supplied " + k);
            }
            int pf;
            if (objArr.length == 4) {
                pf = PrimitiveObjectInspectorUtils.getInt(objArr[3], this.pOI);
                if (pf < 1) {
                    throw new HiveException(getClass().getSimpleName() + " needs 'pf' to be at least 1, but you supplied " + pf);
                }
            } else {
                // NOTE(review): the function description documents a default of 20; 1 is passed
                // here, presumably raised to the real default inside NGramEstimator - confirm.
                pf = 1;
            }
            buffer.n = n;
            buffer.nge.initialize(k, pf, n);
        }

        /** Converts the elements of one inspected list to Java strings using the element inspector. */
        private ArrayList<String> toStringList(List<?> elements) {
            ArrayList<String> words = new ArrayList<>(elements.size());
            for (int i = 0; i < elements.size(); i++) {
                words.add(PrimitiveObjectInspectorUtils.getString(elements.get(i), this.inputOI));
            }
            return words;
        }

        /**
         * Returns the final result for the group, exactly as produced by the estimator's getNGrams().
         *
         * @param aggregationBuffer the {@code NGramAggBuf} for the group
         */
        @Override // org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator
        public Object terminate(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer) throws HiveException {
            NGramAggBuf buffer = (NGramAggBuf) aggregationBuffer;
            return buffer.nge.getNGrams();
        }

        /**
         * Creates a buffer with a fresh estimator and puts it into its reset state.
         *
         * @return a new {@code NGramAggBuf}
         */
        @Override // org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator
        public GenericUDAFEvaluator.AggregationBuffer getNewAggregationBuffer() throws HiveException {
            NGramAggBuf fresh = new NGramAggBuf();
            fresh.nge = new NGramEstimator();
            // Reuse reset() so a new buffer and a recycled one start from the same state.
            reset(fresh);
            return fresh;
        }

        /**
         * Returns a buffer to its initial state: the estimator is reset and 'n' goes back to 0,
         * the value merge() treats as "no data".
         *
         * @param aggregationBuffer the {@code NGramAggBuf} to clear
         */
        @Override // org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator
        public void reset(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer) throws HiveException {
            NGramAggBuf buffer = (NGramAggBuf) aggregationBuffer;
            buffer.nge.reset();
            buffer.n = 0;
        }

        // Decompiled form of the compiler-generated assertion setup: the flag is true when
        // assertions are NOT desired for GenericUDAFnGrams, in which case iterate() skips
        // its argument-count check.
        static {
            $assertionsDisabled = !GenericUDAFnGrams.class.desiredAssertionStatus();
        }
    }

    /**
     * Validates the argument types of a call and returns the evaluator for it.
     *
     * <p>Accepted signatures have three or four arguments: an array of strings or an array of
     * arrays of strings, followed by 'n', 'k' and optionally 'pf', each of an accepted
     * integer-like primitive type. Arguments are checked left to right and the first
     * violation is reported.
     *
     * @param typeInfoArr the types of the call's arguments
     * @return a new {@code GenericUDAFnGramEvaluator}
     * @throws SemanticException (as {@code UDFArgumentTypeException}) if the argument count or
     *         any argument type is not accepted
     */
    @Override // org.apache.hadoop.hive.ql.udf.generic.GenericUDAFResolver
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] typeInfoArr) throws SemanticException {
        if (typeInfoArr.length != 3 && typeInfoArr.length != 4) {
            throw new UDFArgumentTypeException(typeInfoArr.length - 1, "Please specify either three or four arguments.");
        }
        if (typeInfoArr[0].getCategory() != ObjectInspector.Category.LIST) {
            throw new UDFArgumentTypeException(0, "Only list type arguments are accepted but " + typeInfoArr[0].getTypeName() + " was passed as parameter 1.");
        }

        // Descend at most one level of nesting to reach the type of the individual words.
        TypeInfo wordType = ((ListTypeInfo) typeInfoArr[0]).getListElementTypeInfo();
        if (wordType.getCategory() == ObjectInspector.Category.LIST) {
            wordType = ((ListTypeInfo) wordType).getListElementTypeInfo();
        }
        // Check the category before casting so that deeper nesting or non-primitive elements
        // are reported as an argument error instead of failing with a ClassCastException.
        if (wordType.getCategory() != ObjectInspector.Category.PRIMITIVE) {
            throw new UDFArgumentTypeException(0, "Only arrays of strings or arrays of arrays of strings are accepted but " + typeInfoArr[0].getTypeName() + " was passed as parameter 1.");
        }
        if (((PrimitiveTypeInfo) wordType).getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.STRING) {
            throw new UDFArgumentTypeException(0, "Only array<string> or array<array<string>> is allowed, but " + typeInfoArr[0].getTypeName() + " was passed as parameter 1.");
        }

        // 'n', 'k' and, when present, 'pf'.
        for (int index = 1; index < typeInfoArr.length; index++) {
            requireIntegerArgument(typeInfoArr, index);
        }
        return new GenericUDAFnGramEvaluator();
    }

    /** Throws unless the argument at {@code index} is a primitive of a category accepted for 'n', 'k' and 'pf'. */
    private static void requireIntegerArgument(TypeInfo[] typeInfoArr, int index) throws SemanticException {
        TypeInfo type = typeInfoArr[index];
        boolean accepted = false;
        if (type.getCategory() == ObjectInspector.Category.PRIMITIVE) {
            switch (((PrimitiveTypeInfo) type).getPrimitiveCategory()) {
                case BYTE:
                case SHORT:
                case INT:
                case LONG:
                // NOTE(review): TIMESTAMP is accepted alongside the integer types although the
                // error message speaks of integers only; kept for compatibility - confirm intent.
                case TIMESTAMP:
                    accepted = true;
                    break;
                default:
                    break;
            }
        }
        if (!accepted) {
            throw new UDFArgumentTypeException(index, "Only integers are accepted but " + type.getTypeName() + " was passed as parameter " + (index + 1) + ".");
        }
    }
}
