package org.apache.hadoop.hive.ql.optimizer.calcite.rules;

import com.google.common.base.Function;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.hadoop.hive.ql.optimizer.calcite.CalciteSemanticException;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.TraitsUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveFilter;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveIntersect;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveRelNode;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveTableFunctionScan;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveUnion;
import org.apache.hadoop.hive.ql.optimizer.calcite.translator.SqlFunctionConverter;
import org.apache.hadoop.hive.ql.optimizer.calcite.translator.TypeConverter;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:WEB-INF/lib/hive-exec-2.3.7-mapr-2101.jar:org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveIntersectRewriteRule.class */
/**
 * Rewrites a {@link HiveIntersect} into aggregates, a union and a filter.
 *
 * <p>The plan produced by {@link #onMatch} is, bottom-up:
 * <ol>
 *   <li>for every branch: project {@code (c0, ..., ck, 1)} and group by all of {@code c0..ck},
 *       computing {@code count} of the constant column as {@code c};</li>
 *   <li>a {@link HiveUnion} over those per-branch aggregates;</li>
 *   <li>group by {@code c0..ck}, computing {@code count(c)} and, for INTERSECT ALL, also
 *       {@code min(c)};</li>
 *   <li>a filter keeping rows where {@code count(c)} equals the number of branches. Each branch
 *       contributes at most one row per distinct key (step 1 groups by every column), so this
 *       keeps exactly the keys present in every branch;</li>
 *   <li>INTERSECT DISTINCT: project away the trailing {@code count(c)} column.
 *       INTERSECT ALL: project {@code (min(c), c0, ..., ck)}, feed that into the set-op UDTF
 *       built by {@code HiveCalciteUtil.createUDTFForSetOp}, then project away column 0.
 *       NOTE(review): the UDTF presumably emits each key {@code min(c)} times — confirm against
 *       {@code createUDTFForSetOp}.</li>
 * </ol>
 *
 * <p>Any {@link SemanticException} raised while building the plan is logged at debug level and
 * rethrown wrapped in a {@link RuntimeException}.
 */
public class HiveIntersectRewriteRule extends RelOptRule {
    /** Shared instance of this rule. */
    public static final HiveIntersectRewriteRule INSTANCE = new HiveIntersectRewriteRule();
    protected static final Logger LOG = LoggerFactory.getLogger(HiveIntersectRewriteRule.class);

    /** Matches any {@link HiveIntersect}, whatever its inputs are. */
    private HiveIntersectRewriteRule() {
        super(operand(HiveIntersect.class, any()));
    }

    /**
     * Replaces the matched {@link HiveIntersect} with the aggregate/union/filter plan described
     * in the class documentation.
     *
     * @param call rule call whose operand 0 is the {@link HiveIntersect} being rewritten
     * @throws RuntimeException wrapping any {@link SemanticException} raised while building the
     *     replacement plan
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        final HiveIntersect intersect = (HiveIntersect) call.rel(0);
        final RelOptCluster cluster = intersect.getCluster();
        final RexBuilder rexBuilder = cluster.getRexBuilder();
        // Every count/min/comparison in the rewrite is BIGINT; convert the Hive long type once.
        final RelDataType longType =
                TypeConverter.convert(TypeInfoFactory.longTypeInfo, cluster.getTypeFactory());
        final int branchCount = intersect.getInputs().size();

        // Level 1: one (c0, ..., ck, count) aggregate per branch, then a union over all of them.
        final ImmutableList.Builder<RelNode> branchAggregates = ImmutableList.builder();
        for (RelNode branch : intersect.getInputs()) {
            branchAggregates.add(createBranchAggregate(cluster, rexBuilder, branch, longType));
        }
        final RelNode union = new HiveUnion(cluster, TraitsUtil.getDefaultTraitSet(cluster),
                branchAggregates.build());

        // Level 2: the per-branch count is the last union column; every column before it is a
        // grouping key.
        final int countIndex = union.getRowType().getFieldList().size() - 1;
        final List<AggregateCall> aggCalls = new ArrayList<>();
        aggCalls.add(HiveCalciteUtil.createSingleArgAggCall("count", cluster,
                TypeInfoFactory.longTypeInfo, countIndex, longType));
        if (intersect.all) {
            aggCalls.add(HiveCalciteUtil.createSingleArgAggCall("min", cluster,
                    TypeInfoFactory.longTypeInfo, countIndex, longType));
        }
        final RelNode aggregate = new HiveAggregate(cluster,
                cluster.traitSetOf(HiveRelNode.CONVENTION), union, false,
                ImmutableBitSet.range(countIndex), null, aggCalls);

        final RelNode filter = createBranchCountFilter(cluster, rexBuilder, aggregate,
                countIndex, branchCount, longType);
        call.transformTo(intersect.all
                ? createIntersectAllOutput(cluster, filter)
                : createIntersectDistinctOutput(filter));
    }

    /**
     * Builds the level-1 aggregate for one branch: project {@code (c0, ..., ck, 1)}, group by
     * {@code c0..ck}, and count the constant column.
     */
    private static RelNode createBranchAggregate(RelOptCluster cluster, RexBuilder rexBuilder,
            RelNode branch, RelDataType longType) {
        final int fieldCount = branch.getRowType().getFieldList().size();
        final List<RexNode> projects = new ArrayList<>(fieldCount + 1);
        for (int i = 0; i < fieldCount; i++) {
            projects.add(rexBuilder.makeInputRef(branch, i));
        }
        // Constant column appended at position fieldCount; counting it per group yields the
        // number of occurrences of each distinct row within this branch.
        projects.add(rexBuilder.makeBigintLiteral(BigDecimal.ONE));

        final RelNode projected;
        try {
            projected = HiveProject.create(branch, projects, null);
        } catch (CalciteSemanticException e) {
            throw logAndWrap(e);
        }

        final List<AggregateCall> aggCalls = new ArrayList<>(1);
        aggCalls.add(HiveCalciteUtil.createSingleArgAggCall("count", cluster,
                TypeInfoFactory.longTypeInfo, fieldCount, longType));
        return new HiveAggregate(cluster, cluster.traitSetOf(HiveRelNode.CONVENTION), projected,
                false, ImmutableBitSet.range(fieldCount), null, aggCalls);
    }

    /**
     * Wraps {@code aggregate} in a filter {@code aggregate[countIndex] = branchCount}, using the
     * Hive {@code "="} function resolved for BIGINT operands.
     */
    private static RelNode createBranchCountFilter(RelOptCluster cluster, RexBuilder rexBuilder,
            RelNode aggregate, int countIndex, int branchCount, RelDataType longType) {
        final List<RexNode> operands = new ArrayList<>(2);
        operands.add(rexBuilder.makeInputRef(aggregate, countIndex));
        operands.add(rexBuilder.makeBigintLiteral(BigDecimal.valueOf(branchCount)));
        final ImmutableList<RelDataType> argTypes = ImmutableList.of(longType, longType);

        final RexNode condition;
        try {
            condition = rexBuilder.makeCall(
                    SqlFunctionConverter.getCalciteFn("=", argTypes, longType, true), operands);
        } catch (CalciteSemanticException e) {
            throw logAndWrap(e);
        }
        return new HiveFilter(cluster, cluster.traitSetOf(HiveRelNode.CONVENTION), aggregate,
                condition);
    }

    /**
     * INTERSECT DISTINCT output: the filter's row type is {@code (c0, ..., ck, count(c))}, so
     * project away the trailing count column.
     */
    private static RelNode createIntersectDistinctOutput(RelNode filter) {
        final Set<Integer> dropped = new HashSet<>();
        dropped.add(filter.getRowType().getFieldList().size() - 1);
        try {
            return HiveCalciteUtil.createProjectWithoutColumn(filter, dropped);
        } catch (CalciteSemanticException e) {
            throw logAndWrap(e);
        }
    }

    /**
     * INTERSECT ALL output. The filter's row type is {@code (c0, ..., ck, count(c), min(c))}.
     * Reorders it to {@code (min(c), c0, ..., ck)}, builds the set-op UDTF over that, and
     * projects away column 0 of the UDTF result.
     */
    private static RelNode createIntersectAllOutput(RelOptCluster cluster, RelNode filter) {
        final List<RelDataTypeField> fields = filter.getRowType().getFieldList();
        final int minIndex = fields.size() - 1;

        final List<RexNode> udtfInputs = new ArrayList<>();
        udtfInputs.add(fieldRef(fields.get(minIndex)));
        // Skip the trailing count(c) and min(c) columns; keep the grouping keys in order.
        for (int i = 0; i < minIndex - 1; i++) {
            udtfInputs.add(fieldRef(fields.get(i)));
        }

        try {
            final RelNode udtfInput = HiveProject.create(filter, udtfInputs, null);
            final HiveTableFunctionScan udtf =
                    HiveCalciteUtil.createUDTFForSetOp(cluster, udtfInput);
            final Set<Integer> dropped = new HashSet<>();
            dropped.add(0);
            return HiveCalciteUtil.createProjectWithoutColumn(udtf, dropped);
        } catch (SemanticException e) {
            throw logAndWrap(e);
        }
    }

    /** Creates an input reference to {@code field} using the field's own index and type. */
    private static RexNode fieldRef(RelDataTypeField field) {
        return new RexInputRef(field.getIndex(), field.getType());
    }

    /** Logs {@code e} at debug level and returns it wrapped in an unchecked exception. */
    private static RuntimeException logAndWrap(SemanticException e) {
        LOG.debug(e.toString());
        return new RuntimeException(e);
    }
}
