/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.sql.validate.implicit;

import java.math.BigDecimal;
import java.util.AbstractList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlCallBinding;
import org.apache.calcite.sql.SqlFunction;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlSelect;
import org.apache.calcite.sql.SqlWith;
import org.apache.calcite.sql.fun.SqlCase;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.sql.validate.SqlUserDefinedTableMacro;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.sql.validate.SqlValidatorScope;
import org.apache.calcite.sql.validate.implicit.AbstractTypeCoercion;
import org.apache.calcite.util.Util;

/**
 * Default implementation of the implicit type coercion rules.
 *
 * <p>Each {@code *Coercion} method inspects the operand types of a call (or the row type of
 * a query) and picks a target type. It then rewrites the SQL tree in place by delegating to
 * helpers inherited from {@link AbstractTypeCoercion}, such as {@code coerceOperandType},
 * {@code coerceColumnType} and {@code coerceOperandsType}. A method returns {@code true}
 * when at least one of those delegated coercions reported success, so callers can tell
 * whether the tree changed.
 */
public class TypeCoercionImpl extends AbstractTypeCoercion {
    public TypeCoercionImpl(SqlValidator validator) {
        super(validator);
    }

    /**
     * Coerces column {@code columnIndex} of {@code query}'s row type to {@code targetType}.
     *
     * <p>Supported query kinds:
     * <ul>
     *   <li>SELECT: coerces the select-list item, using the SELECT's own scope.</li>
     *   <li>VALUES: coerces the column in every row constructor. It stops at the first
     *       row that cannot be coerced.</li>
     *   <li>WITH: recurses into the body, using the WITH scope.</li>
     *   <li>UNION / INTERSECT / EXCEPT: recurses into both inputs.</li>
     * </ul>
     * On success the inferred column type of {@code query} is updated as well.
     *
     * @param scope       scope used for VALUES and set-operator inputs
     * @param query       query whose column should be coerced
     * @param columnIndex 0-based column ordinal
     * @param targetType  desired column type
     * @return {@code true} if the column was coerced; {@code false} for unsupported kinds or
     *     when any input could not be coerced
     */
    @Override
    public boolean rowTypeCoercion(SqlValidatorScope scope, SqlNode query, int columnIndex,
            RelDataType targetType) {
        switch (query.getKind()) {
            case SELECT: {
                SqlSelect select = (SqlSelect) query;
                SqlValidatorScope selectScope = this.validator.getSelectScope(select);
                if (!this.coerceColumnType(selectScope, select.getSelectList(), columnIndex,
                        targetType)) {
                    return false;
                }
                this.updateInferredColumnType(selectScope, query, columnIndex, targetType);
                return true;
            }
            case VALUES: {
                for (SqlNode rowConstructor : ((SqlCall) query).getOperandList()) {
                    if (!this.coerceOperandType(scope, (SqlCall) rowConstructor, columnIndex,
                            targetType)) {
                        return false;
                    }
                }
                this.updateInferredColumnType(scope, query, columnIndex, targetType);
                return true;
            }
            case WITH: {
                SqlNode body = ((SqlWith) query).body;
                return this.rowTypeCoercion(this.validator.getWithScope(query), body,
                        columnIndex, targetType);
            }
            case UNION:
            case INTERSECT:
            case EXCEPT: {
                SqlCall setOp = (SqlCall) query;
                // Short-circuit: if the left input fails, the right input is not touched.
                boolean coerced =
                        this.rowTypeCoercion(scope, (SqlCall) setOp.operand(0), columnIndex,
                                targetType)
                        && this.rowTypeCoercion(scope, (SqlCall) setOp.operand(1), columnIndex,
                                targetType);
                if (coerced) {
                    this.updateInferredColumnType(scope, query, columnIndex, targetType);
                }
                return coerced;
            }
            default:
                return false;
        }
    }

    /**
     * Coerces the operands of a two-operand arithmetic call when one side is a string and
     * the other is numeric.
     *
     * <p>{@code PLUS}/{@code MINUS} calls with an interval operand are never coerced here.
     *
     * @param binding call binding of the arithmetic call
     * @return {@code true} if an operand was coerced
     */
    @Override
    public boolean binaryArithmeticCoercion(SqlCallBinding binding) {
        if (binding.getOperandCount() != 2) {
            return false;
        }
        SqlKind kind = binding.getOperator().getKind();
        RelDataType type1 = binding.getOperandType(0);
        RelDataType type2 = binding.getOperandType(1);
        boolean additive = kind == SqlKind.PLUS || kind == SqlKind.MINUS;
        if (additive && (SqlTypeUtil.isInterval(type1) || SqlTypeUtil.isInterval(type2))) {
            // Datetime/interval arithmetic is left alone.
            return false;
        }
        return kind.belongsTo(SqlKind.BINARY_ARITHMETIC)
                && this.binaryArithmeticWithStrings(binding, type1, type2);
    }

    /**
     * For STRING op NUMERIC (either order), casts the string operand to the numeric
     * operand's type.
     *
     * <p>A DECIMAL target is replaced with the maximum-precision/scale DECIMAL from the
     * type factory, presumably so that the string value is not truncated.
     *
     * @param binding call binding
     * @param left    type of operand 0
     * @param right   type of operand 1
     * @return {@code true} if the string operand was coerced
     */
    protected boolean binaryArithmeticWithStrings(SqlCallBinding binding, RelDataType left,
            RelDataType right) {
        if (SqlTypeUtil.isString(left) && SqlTypeUtil.isNumeric(right)) {
            if (SqlTypeUtil.isDecimal(right)) {
                right = SqlTypeUtil.getMaxPrecisionScaleDecimal(this.factory);
            }
            return this.coerceOperandType(binding.getScope(), binding.getCall(), 0, right);
        }
        if (SqlTypeUtil.isNumeric(left) && SqlTypeUtil.isString(right)) {
            if (SqlTypeUtil.isDecimal(left)) {
                left = SqlTypeUtil.getMaxPrecisionScaleDecimal(this.factory);
            }
            return this.coerceOperandType(binding.getScope(), binding.getCall(), 1, left);
        }
        return false;
    }

    /**
     * Coerces the operands of a comparison call.
     *
     * <ul>
     *   <li>Binary equality ({@code =}, {@code <>}): STRING vs DATETIME, then
     *       BOOLEAN vs NUMERIC.</li>
     *   <li>Binary comparison: casts both operands to
     *       {@code commonTypeForBinaryComparison} when one exists.</li>
     *   <li>{@code BETWEEN}: casts all operands to {@link #commonTypeForComparison}.</li>
     * </ul>
     *
     * @param binding call binding of the comparison
     * @return {@code true} if any of the steps above rewrote the call
     */
    @Override
    public boolean binaryComparisonCoercion(SqlCallBinding binding) {
        SqlKind kind = binding.getOperator().getKind();
        int operandCnt = binding.getOperandCount();
        boolean coerced = false;
        if (operandCnt == 2) {
            RelDataType type1 = binding.getOperandType(0);
            RelDataType type2 = binding.getOperandType(1);
            if (kind.belongsTo(SqlKind.BINARY_EQUALITY)) {
                // `|=` always evaluates the right side, so both rules are attempted.
                coerced |= this.dateTimeStringEquality(binding, type1, type2);
                coerced |= this.booleanEquality(binding, type1, type2);
            }
            if (kind.belongsTo(SqlKind.BINARY_COMPARISON)) {
                RelDataType commonType = this.commonTypeForBinaryComparison(type1, type2);
                if (commonType != null) {
                    // Accumulate rather than overwrite: an equality rewrite above has already
                    // mutated the call and must still be reported.
                    coerced |= this.coerceOperandsType(binding.getScope(), binding.getCall(),
                            commonType);
                }
            }
        }
        if (kind == SqlKind.BETWEEN) {
            List<RelDataType> operandTypes = Util.range(operandCnt).stream()
                    .map(binding::getOperandType)
                    .collect(Collectors.toList());
            RelDataType commonType = this.commonTypeForComparison(operandTypes);
            if (commonType != null) {
                coerced |= this.coerceOperandsType(binding.getScope(), binding.getCall(),
                        commonType);
            }
        }
        return coerced;
    }

    /**
     * Finds a type to which all of {@code dataTypes} can be cast for comparison.
     *
     * <p>Types are folded left to right. Each step works as follows:
     * <ul>
     *   <li>If the accumulated type and the next type are
     *       {@link SqlTypeUtil#sameNamedType same-named}, it uses the factory's
     *       {@code leastRestrictive}.</li>
     *   <li>Otherwise it uses {@code commonTypeForBinaryComparison}.</li>
     * </ul>
     *
     * @param dataTypes at least two operand types
     * @return the common type; {@code null} if every adjacent pair is already same-named
     *     (no coercion needed) or if some step found no common type
     */
    protected RelDataType commonTypeForComparison(List<RelDataType> dataTypes) {
        assert dataTypes.size() >= 2 : "expected at least 2 types, got " + dataTypes.size();
        boolean allWithSameName = true;
        for (int i = 1; i < dataTypes.size() && allWithSameName; ++i) {
            allWithSameName = SqlTypeUtil.sameNamedType(dataTypes.get(i - 1), dataTypes.get(i));
        }
        if (allWithSameName) {
            return null;
        }
        RelDataType commonType = dataTypes.get(0);
        for (int i = 1; i < dataTypes.size() && commonType != null; ++i) {
            commonType = this.pairwiseCommonType(commonType, dataTypes.get(i));
        }
        return commonType;
    }

    /** Computes one fold step of {@link #commonTypeForComparison}; may return {@code null}. */
    private RelDataType pairwiseCommonType(RelDataType acc, RelDataType next) {
        return SqlTypeUtil.sameNamedType(acc, next)
                ? this.factory.leastRestrictive(Arrays.asList(acc, next))
                : this.commonTypeForBinaryComparison(acc, next);
    }

    /**
     * For CHARACTER = DATETIME (either order), casts the character operand to the datetime
     * operand's type.
     *
     * @param binding call binding
     * @param left    type of operand 0
     * @param right   type of operand 1
     * @return {@code true} if the character operand was coerced
     */
    protected boolean dateTimeStringEquality(SqlCallBinding binding, RelDataType left,
            RelDataType right) {
        if (SqlTypeUtil.isCharacter(left) && SqlTypeUtil.isDatetime(right)) {
            return this.coerceOperandType(binding.getScope(), binding.getCall(), 0, right);
        }
        if (SqlTypeUtil.isCharacter(right) && SqlTypeUtil.isDatetime(left)) {
            return this.coerceOperandType(binding.getScope(), binding.getCall(), 1, left);
        }
        return false;
    }

    /**
     * Handles NUMERIC = BOOLEAN (either order).
     *
     * <p>If the numeric operand is a literal, it is replaced in the call by a boolean
     * literal: {@code TRUE} when its value equals 1, {@code FALSE} otherwise. If the numeric
     * operand is not a literal, the boolean operand is cast to the numeric operand's type.
     *
     * @param binding call binding
     * @param left    type of operand 0
     * @param right   type of operand 1
     * @return {@code true} if the call was rewritten
     */
    protected boolean booleanEquality(SqlCallBinding binding, RelDataType left,
            RelDataType right) {
        if (SqlTypeUtil.isNumeric(left) && SqlTypeUtil.isBoolean(right)) {
            return this.numericVersusBoolean(binding, 0, left);
        }
        if (SqlTypeUtil.isNumeric(right) && SqlTypeUtil.isBoolean(left)) {
            return this.numericVersusBoolean(binding, 1, right);
        }
        return false;
    }

    /**
     * Applies the {@link #booleanEquality} rewrite.
     *
     * @param numericIndex which operand (0 or 1) is numeric
     * @param numericType  its type
     */
    private boolean numericVersusBoolean(SqlCallBinding binding, int numericIndex,
            RelDataType numericType) {
        SqlNode numericNode = binding.operand(numericIndex);
        if (numericNode.getKind() == SqlKind.LITERAL) {
            BigDecimal val = ((SqlLiteral) numericNode).bigDecimalValue();
            boolean truth = val.compareTo(BigDecimal.ONE) == 0;
            binding.getCall().setOperand(numericIndex,
                    SqlLiteral.createBoolean(truth, SqlParserPos.ZERO));
            return true;
        }
        // Non-literal numeric side: cast the boolean (other) operand to the numeric type.
        return this.coerceOperandType(binding.getScope(), binding.getCall(), 1 - numericIndex,
                numericType);
    }

    /**
     * Casts every THEN branch and the ELSE branch of a CASE call to a common wider type.
     *
     * <p>THEN operands go through {@code coerceColumnType}. The ELSE operand (operand 3 of
     * the {@code SqlCase}) is coerced only when {@code needToCast} says so.
     *
     * @param callBinding binding of the CASE call
     * @return {@code true} if any branch was coerced; {@code false} if no wider type exists
     */
    @Override
    public boolean caseWhenCoercion(SqlCallBinding callBinding) {
        SqlCase caseCall = (SqlCase) callBinding.getCall();
        SqlValidatorScope scope = callBinding.getScope();
        SqlNodeList thenList = caseCall.getThenOperands();
        List<RelDataType> argTypes = new ArrayList<>();
        for (SqlNode node : thenList) {
            argTypes.add(this.validator.deriveType(scope, node));
        }
        SqlNode elseOp = caseCall.getElseOperand();
        argTypes.add(this.validator.deriveType(scope, elseOp));
        RelDataType widerType = this.getWiderTypeFor(argTypes, true);
        if (widerType == null) {
            return false;
        }
        boolean coerced = false;
        for (int i = 0; i < thenList.size(); ++i) {
            coerced |= this.coerceColumnType(scope, thenList, i, widerType);
        }
        if (this.needToCast(scope, elseOp, widerType)) {
            coerced |= this.coerceOperandType(scope, caseCall, 3, widerType);
        }
        return coerced;
    }

    /**
     * Coerces both sides of an {@code IN} call, column by column, to a common type.
     *
     * <p>All target types are computed first, so that a column with no common type aborts
     * before any rewrite. Each column then coerces:
     * <ul>
     *   <li>the left side: a ROW's field, or operand 0 as a whole;</li>
     *   <li>the right side: either a literal list (row-valued or scalar), or a subquery via
     *       {@link #rowTypeCoercion}.</li>
     * </ul>
     *
     * @param binding call binding
     * @return {@code true} if anything on either side was coerced
     */
    @Override
    public boolean inOperationCoercion(SqlCallBinding binding) {
        SqlOperator operator = binding.getOperator();
        if (operator.getKind() != SqlKind.IN) {
            return false;
        }
        assert binding.getOperandCount() == 2;
        RelDataType type1 = binding.getOperandType(0);
        RelDataType type2 = binding.getOperandType(1);
        SqlNode node1 = binding.operand(0);
        SqlNode node2 = binding.operand(1);
        SqlValidatorScope scope = binding.getScope();
        if (type1.isStruct() && type2.isStruct()
                && type1.getFieldCount() != type2.getFieldCount()) {
            return false;
        }
        int colCount = type1.isStruct() ? type1.getFieldCount() : 1;
        List<RelDataType> widenTypes = new ArrayList<>(colCount);
        for (int i = 0; i < colCount; ++i) {
            RelDataType leftType = columnType(type1, i);
            RelDataType rightType = columnType(type2, i);
            RelDataType widenType = this.commonTypeForBinaryComparison(leftType, rightType);
            if (widenType == null) {
                widenType = this.getTightestCommonType(leftType, rightType);
            }
            if (widenType == null) {
                return false;
            }
            widenTypes.add(widenType);
        }
        boolean coerced = false;
        for (int i = 0; i < colCount; ++i) {
            RelDataType desired = widenTypes.get(i);
            if (node1.getKind() == SqlKind.ROW) {
                assert node1 instanceof SqlCall;
                if (this.coerceOperandType(scope, (SqlCall) node1, i, desired)) {
                    this.updateInferredColumnType(scope, node1, i, desired);
                    coerced = true;
                }
            } else {
                coerced |= this.coerceOperandType(scope, binding.getCall(), 0, desired);
            }
            if (node2 instanceof SqlNodeList) {
                // Previously dropped: list rewrites must be reported to the caller.
                coerced |= this.coerceInList(scope, (SqlNodeList) node2, type2.isStruct(), i,
                        desired);
            } else {
                SqlValidatorScope scope1 = node2 instanceof SqlSelect
                        ? this.validator.getSelectScope((SqlSelect) node2)
                        : scope;
                coerced |= this.rowTypeCoercion(scope1, node2, i, desired);
            }
        }
        return coerced;
    }

    /** Returns field {@code ordinal}'s type for a struct type, or {@code type} itself. */
    private static RelDataType columnType(RelDataType type, int ordinal) {
        return type.isStruct() ? type.getFieldList().get(ordinal).getType() : type;
    }

    /**
     * Coerces column {@code column} of the IN-list on the right-hand side of {@code IN}.
     *
     * <ul>
     *   <li>Row-valued list: coerces field {@code column} of every row call.</li>
     *   <li>Scalar list: coerces every element.</li>
     * </ul>
     * Updates the list's inferred type when something changed.
     *
     * @return {@code true} if any element was coerced
     */
    private boolean coerceInList(SqlValidatorScope scope, SqlNodeList list, boolean rowValued,
            int column, RelDataType desired) {
        boolean listCoerced = false;
        if (rowValued) {
            for (SqlNode node : list) {
                assert node instanceof SqlCall;
                listCoerced |= this.coerceOperandType(scope, (SqlCall) node, column, desired);
            }
            if (listCoerced) {
                this.updateInferredColumnType(scope, list, column, desired);
            }
        } else {
            for (int j = 0; j < list.size(); ++j) {
                listCoerced |= this.coerceColumnType(scope, list, j, desired);
            }
            if (listCoerced) {
                this.updateInferredType(list, desired);
            }
        }
        return listCoerced;
    }

    /**
     * Casts each operand of a built-in function call to the implicit-cast type for its
     * expected family.
     *
     * <p>Nothing happens unless {@code canImplicitTypeCast} accepts all operands. An operand
     * is rewritten only when {@code implicitCast} yields a type that is not the very same
     * instance as the operand's current type.
     *
     * @param binding          call binding
     * @param operandTypes     current operand types, one per operand
     * @param expectedFamilies expected family for each operand
     * @return {@code true} if any operand was coerced
     */
    @Override
    public boolean builtinFunctionCoercion(SqlCallBinding binding, List<RelDataType> operandTypes,
            List<SqlTypeFamily> expectedFamilies) {
        assert binding.getOperandCount() == operandTypes.size();
        if (!this.canImplicitTypeCast(operandTypes, expectedFamilies)) {
            return false;
        }
        boolean coerced = false;
        for (int i = 0; i < operandTypes.size(); ++i) {
            RelDataType current = operandTypes.get(i);
            RelDataType implicitType = this.implicitCast(current, expectedFamilies.get(i));
            coerced |= implicitType != null
                    && current != implicitType
                    && this.coerceOperandType(binding.getScope(), binding.getCall(), i,
                            implicitType);
        }
        return coerced;
    }

    /**
     * Casts the arguments of a user-defined function call to the function's declared
     * parameter types.
     *
     * <p>Table macros are skipped. Named arguments ({@code name => value}) are matched by
     * parameter name, and an unknown name aborts with {@code false}. Positional operands
     * beyond the declared parameter list are left as-is rather than failing with an
     * {@code IndexOutOfBoundsException}.
     *
     * @param scope    validator scope
     * @param call     the function call
     * @param function the resolved function
     * @return {@code true} if any argument was coerced
     */
    @Override
    public boolean userDefinedFunctionCoercion(SqlValidatorScope scope, SqlCall call,
            SqlFunction function) {
        List<RelDataType> paramTypes = function.getParamTypes();
        assert paramTypes != null;
        if (function instanceof SqlUserDefinedTableMacro) {
            return false;
        }
        boolean coerced = false;
        for (int i = 0; i < call.operandCount(); ++i) {
            SqlNode operand = call.operand(i);
            if (operand.getKind() == SqlKind.ARGUMENT_ASSIGNMENT) {
                SqlCall assignment = (SqlCall) operand;
                // Operand 0 is the value, operand 1 the parameter name.
                String name = ((SqlIdentifier) assignment.getOperandList().get(1)).getSimple();
                int formalIndex = function.getParamNames().indexOf(name);
                if (formalIndex < 0) {
                    return false;
                }
                coerced |= this.coerceOperandType(scope, assignment, 0,
                        paramTypes.get(formalIndex));
            } else if (i < paramTypes.size()) {
                coerced |= this.coerceOperandType(scope, call, i, paramTypes.get(i));
            }
        }
        return coerced;
    }
}

