package org.apache.hadoop.hive.ql.udf.generic;

import org.apache.derby.impl.store.raw.log.LogCounter;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.exec.vector.VectorizedExpressions;
import org.apache.hadoop.hive.ql.exec.vector.expressions.FuncRoundWithNumDigitsDecimalToDecimal;
import org.apache.hadoop.hive.ql.exec.vector.expressions.RoundWithNumDigitsDoubleToDouble;
import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.FuncRoundDecimalToDecimal;
import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.FuncRoundDoubleToDouble;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.io.ByteWritable;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
import org.apache.hadoop.hive.serde2.io.ShortWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.AbstractPrimitiveWritableObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableConstantByteObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableConstantIntObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableConstantLongObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableConstantShortObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;

/**
 * Generic UDF behind {@code round(x[, d])}: rounds {@code x} to {@code d} decimal places.
 *
 * <p>The optional second argument must be an integer <em>constant</em> (or a VOID/NULL literal);
 * it is read once in {@link #initialize(ObjectInspector[])} and stored in {@link #scale}.
 * When it is omitted the scale stays at its default of {@code 0}.
 *
 * <p>Result type by input type, as set up in {@code initialize}:
 * <ul>
 *   <li>VOID, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE: same primitive category as the input;</li>
 *   <li>DECIMAL: a decimal type computed by {@code getOutputTypeInfo};</li>
 *   <li>STRING, VARCHAR, CHAR: DOUBLE (the input is converted to a double first);</li>
 *   <li>anything else is rejected with a {@link UDFArgumentTypeException}.</li>
 * </ul>
 */
@Description(name = "round", value = "_FUNC_(x[, d]) - round x to d decimal places", extended = "Example:\n  > SELECT _FUNC_(12.3456, 1) FROM src LIMIT 1;\n  12.3'")
@VectorizedExpressions({FuncRoundDoubleToDouble.class, RoundWithNumDigitsDoubleToDouble.class, FuncRoundWithNumDigitsDecimalToDecimal.class, FuncRoundDecimalToDecimal.class})
/* loaded from: input_file:WEB-INF/lib/hive-exec-2.3.6-mapr-2110-r5-core.jar:org/apache/hadoop/hive/ql/udf/generic/GenericUDFRound.class */
public class GenericUDFRound extends GenericUDF {
    /**
     * Upper bound applied to both the precision and the scale of a DECIMAL result type.
     * NOTE(review): presumably mirrors HiveDecimal.MAX_PRECISION / MAX_SCALE - confirm.
     */
    private static final int MAX_DECIMAL_DIGITS = 38;

    /**
     * Rounding-mode code handed to {@code HiveDecimalWritable.mutateSetScale}.
     * The value 4 equals {@code java.math.BigDecimal.ROUND_HALF_UP};
     * NOTE(review): presumably this is HiveDecimal.ROUND_HALF_UP - confirm.
     */
    private static final int DECIMAL_ROUNDING_MODE = 4;

    /** Inspector of the first argument (the value being rounded). */
    private transient PrimitiveObjectInspector inputOI;

    /** Number of decimal places to round to; negative values are accepted. Defaults to 0. */
    private int scale = 0;

    /** Primitive category of the first argument, cached by {@code initialize}. */
    private transient PrimitiveObjectInspector.PrimitiveCategory inputType;

    /** Converter from the string-group input to a double writable; only set for STRING/VARCHAR/CHAR. */
    private transient ObjectInspectorConverters.Converter converterFromString;

    /**
     * Validates the argument list, captures the constant scale (if given) and selects the
     * output object inspector based on the type of the first argument.
     *
     * @param arguments inspectors for the one or two call arguments
     * @return the writable object inspector describing the result type
     * @throws UDFArgumentLengthException if there are not exactly one or two arguments
     * @throws UDFArgumentTypeException if an argument is not primitive, the scale is not an
     *         integer constant, or the input type is neither numeric nor string-group
     * @throws UDFArgumentException if a LONG scale constant does not fit in an {@code int}
     */
    @Override // org.apache.hadoop.hive.ql.udf.generic.GenericUDF
    public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
        if (arguments.length < 1 || arguments.length > 2) {
            throw new UDFArgumentLengthException("ROUND requires one or two argument, got " + arguments.length);
        }
        if (arguments[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
            throw new UDFArgumentTypeException(0, "ROUND input only takes primitive types, got " + arguments[0].getTypeName());
        }
        this.inputOI = (PrimitiveObjectInspector) arguments[0];

        if (arguments.length == 2) {
            if (arguments[1].getCategory() != ObjectInspector.Category.PRIMITIVE) {
                throw new UDFArgumentTypeException(1, "ROUND second argument only takes primitive types, got " + arguments[1].getTypeName());
            }
            readConstantScale((PrimitiveObjectInspector) arguments[1]);
        }

        this.inputType = this.inputOI.getPrimitiveCategory();
        AbstractPrimitiveWritableObjectInspector outputOI;
        switch (this.inputType) {
            case VOID:
            case BYTE:
            case SHORT:
            case INT:
            case LONG:
            case FLOAT:
            case DOUBLE:
                // Result keeps the primitive category of the input.
                outputOI = PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(this.inputType);
                break;
            case DECIMAL:
                // Precision/scale of the result depend on the requested scale.
                outputOI = PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(
                        getOutputTypeInfo((DecimalTypeInfo) this.inputOI.getTypeInfo(), this.scale));
                break;
            case STRING:
            case VARCHAR:
            case CHAR:
                // String-group inputs are converted to double and rounded as doubles.
                outputOI = PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(
                        PrimitiveObjectInspector.PrimitiveCategory.DOUBLE);
                // NOTE(review): explicit upcasts kept from the original; presumably they pin the
                // (ObjectInspector, ObjectInspector) overload - confirm before removing.
                this.converterFromString = ObjectInspectorConverters.getConverter(
                        (ObjectInspector) this.inputOI, (ObjectInspector) outputOI);
                break;
            default:
                throw new UDFArgumentTypeException(0, "Only numeric or string group data types are allowed for ROUND function. Got " + this.inputType.name());
        }
        return outputOI;
    }

    /**
     * Reads the constant scale from the inspector of the second argument into {@link #scale}.
     * A VOID second argument leaves {@link #scale} untouched.
     *
     * @param scaleOI inspector of the second argument
     * @throws UDFArgumentTypeException if the argument is not an integer-typed constant
     * @throws UDFArgumentException if a LONG constant is outside the {@code int} range
     */
    private void readConstantScale(PrimitiveObjectInspector scaleOI) throws UDFArgumentException {
        switch (scaleOI.getPrimitiveCategory()) {
            case VOID:
                return;
            case BYTE:
                if (!(scaleOI instanceof WritableConstantByteObjectInspector)) {
                    throw nonConstantScale();
                }
                this.scale = ((WritableConstantByteObjectInspector) scaleOI).getWritableConstantValue().get();
                return;
            case SHORT:
                if (!(scaleOI instanceof WritableConstantShortObjectInspector)) {
                    throw nonConstantScale();
                }
                this.scale = ((WritableConstantShortObjectInspector) scaleOI).getWritableConstantValue().get();
                return;
            case INT:
                if (!(scaleOI instanceof WritableConstantIntObjectInspector)) {
                    throw nonConstantScale();
                }
                this.scale = ((WritableConstantIntObjectInspector) scaleOI).getWritableConstantValue().get();
                return;
            case LONG:
                if (!(scaleOI instanceof WritableConstantLongObjectInspector)) {
                    throw nonConstantScale();
                }
                long longScale = ((WritableConstantLongObjectInspector) scaleOI).getWritableConstantValue().get();
                // The scale is stored as an int, so the constant must fit in the int range.
                if (longScale < Integer.MIN_VALUE || longScale > Integer.MAX_VALUE) {
                    throw new UDFArgumentException(getFuncName().toUpperCase() + " scale argument out of allowed range");
                }
                this.scale = (int) longScale;
                return;
            default:
                throw new UDFArgumentTypeException(1, getFuncName().toUpperCase() + " second argument only takes integer constant");
        }
    }

    /** Builds the exception raised when the second argument has an integer type but is not a constant. */
    private UDFArgumentTypeException nonConstantScale() {
        return new UDFArgumentTypeException(1, getFuncName().toUpperCase() + " second argument only takes constant");
    }

    /**
     * Computes the DECIMAL result type for rounding a value of {@code inputTypeInfo} to
     * {@code requestedScale} decimal places.
     *
     * @param inputTypeInfo decimal type of the input
     * @param requestedScale scale requested by the caller; may be negative
     * @return decimal type whose scale is {@code requestedScale} clamped to
     *         {@code [0, MAX_DECIMAL_DIGITS]} and whose precision is capped at
     *         {@code MAX_DECIMAL_DIGITS}
     */
    private static DecimalTypeInfo getOutputTypeInfo(DecimalTypeInfo inputTypeInfo, int requestedScale) {
        int inputScale = inputTypeInfo.scale();
        int integerDigits = inputTypeInfo.precision() - inputScale;
        if (requestedScale < inputScale) {
            // Dropping fractional digits can carry into a new integer digit (e.g. 9.9 -> 10).
            integerDigits++;
        }
        int outputScale = requestedScale < 0 ? 0 : Math.min(requestedScale, MAX_DECIMAL_DIGITS);
        int outputPrecision = Math.min(integerDigits + outputScale, MAX_DECIMAL_DIGITS);
        return TypeInfoFactory.getDecimalTypeInfo(outputPrecision, outputScale);
    }

    /**
     * Rounds the first argument to {@link #scale} decimal places.
     *
     * <p>Returns {@code null} when a second argument is present and null, when the first
     * argument is null, when the input type is VOID, or when a string-group input cannot be
     * converted to a double. For BYTE/SHORT/INT/LONG inputs with a non-negative scale the
     * input writable is returned as is.
     *
     * @param arguments deferred call arguments (one or two)
     * @return a writable of the type announced by {@code initialize}, or {@code null}
     * @throws HiveException if an argument cannot be evaluated or the input type is unsupported
     */
    @Override // org.apache.hadoop.hive.ql.udf.generic.GenericUDF
    public Object evaluate(GenericUDF.DeferredObject[] arguments) throws HiveException {
        // The second argument is checked first, matching the original evaluation order.
        if (arguments.length == 2 && (arguments[1] == null || arguments[1].get() == null)) {
            return null;
        }
        if (arguments[0] == null) {
            return null;
        }
        Object input = arguments[0].get();
        if (input == null) {
            return null;
        }

        switch (this.inputType) {
            case VOID:
                return null;
            case BYTE: {
                ByteWritable value = (ByteWritable) this.inputOI.getPrimitiveWritableObject(input);
                // Integers have no fractional digits, so only a negative scale changes them.
                return this.scale >= 0 ? value : new ByteWritable((byte) round((long) value.get(), this.scale));
            }
            case SHORT: {
                ShortWritable value = (ShortWritable) this.inputOI.getPrimitiveWritableObject(input);
                return this.scale >= 0 ? value : new ShortWritable((short) round((long) value.get(), this.scale));
            }
            case INT: {
                IntWritable value = (IntWritable) this.inputOI.getPrimitiveWritableObject(input);
                return this.scale >= 0 ? value : new IntWritable((int) round((long) value.get(), this.scale));
            }
            case LONG: {
                LongWritable value = (LongWritable) this.inputOI.getPrimitiveWritableObject(input);
                return this.scale >= 0 ? value : new LongWritable(round(value.get(), this.scale));
            }
            case DECIMAL:
                return round((HiveDecimalWritable) this.inputOI.getPrimitiveWritableObject(input), this.scale);
            case FLOAT: {
                float value = ((FloatWritable) this.inputOI.getPrimitiveWritableObject(input)).get();
                // Rounded through the double overload, then narrowed back to float.
                return new FloatWritable((float) round((double) value, this.scale));
            }
            case DOUBLE:
                return round((DoubleWritable) this.inputOI.getPrimitiveWritableObject(input), this.scale);
            case STRING:
            case VARCHAR:
            case CHAR: {
                DoubleWritable converted = (DoubleWritable) this.converterFromString.convert(input);
                if (converted == null) {
                    return null;
                }
                return round(converted, this.scale);
            }
            default:
                throw new UDFArgumentTypeException(0, "Only numeric or string group data types are allowed for ROUND function. Got " + this.inputType.name());
        }
    }

    /**
     * Rounds a decimal to the given scale without modifying the input.
     *
     * @param input decimal value to round; it is copied, not mutated
     * @param scale target scale
     * @return a new writable holding the copy after {@code mutateSetScale(scale, DECIMAL_ROUNDING_MODE)}
     */
    protected HiveDecimalWritable round(HiveDecimalWritable input, int scale) {
        HiveDecimalWritable result = new HiveDecimalWritable(input);
        result.mutateSetScale(scale, DECIMAL_ROUNDING_MODE);
        return result;
    }

    /**
     * Rounds an integral value; delegates to {@code RoundUtils.round(long, int)}.
     *
     * @param input value to round
     * @param scale target scale
     * @return the rounded value
     */
    protected long round(long input, int scale) {
        return RoundUtils.round(input, scale);
    }

    /**
     * Rounds a floating-point value; delegates to {@code RoundUtils.round(double, int)}.
     *
     * @param input value to round
     * @param scale target scale
     * @return the rounded value
     */
    protected double round(double input, int scale) {
        return RoundUtils.round(input, scale);
    }

    /**
     * Rounds a double writable, passing NaN and infinities through unchanged.
     *
     * @param input value to round
     * @param scale target scale
     * @return a new writable with the rounded (or passed-through) value
     */
    protected DoubleWritable round(DoubleWritable input, int scale) {
        double value = input.get();
        if (Double.isNaN(value) || Double.isInfinite(value)) {
            return new DoubleWritable(value);
        }
        return new DoubleWritable(RoundUtils.round(value, scale));
    }

    /**
     * Builds the display string for this call via {@code getStandardDisplayString("round", ...)}.
     *
     * @param children display strings of the argument expressions
     * @return the display string for the call
     */
    @Override // org.apache.hadoop.hive.ql.udf.generic.GenericUDF
    public String getDisplayString(String[] children) {
        return getStandardDisplayString("round", children);
    }
}
