package org.apache.hadoop.hive.ql.optimizer.calcite.rules;

import hive.com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.tools.RelBuilder;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate;

/**
 * Planner rule that simplifies the aggregate calls of a {@link HiveAggregate}.
 *
 * <p>One pass over the aggregate calls applies two rewrites:
 * <ul>
 *   <li>a non-distinct {@code COUNT} drops every argument whose input column type is not
 *       nullable;</li>
 *   <li>an aggregate call that is equal to an earlier one (after the rewrite above) is computed
 *       only once; a projection on top of the reduced aggregate then repeats the shared column
 *       so that the output keeps its original column positions.</li>
 * </ul>
 *
 * <p>When neither rewrite applies, the rule leaves the plan untouched.
 */
public class HiveAggregateReduceRule extends RelOptRule {
    /** Shared instance of this rule. */
    public static final HiveAggregateReduceRule INSTANCE = new HiveAggregateReduceRule();

    /**
     * Creates the rule: it matches any {@link HiveAggregate}, uses the Hive builder factory and
     * passes {@code null} as the last constructor argument. Use {@link #INSTANCE}.
     */
    private HiveAggregateReduceRule() {
        super(operand(HiveAggregate.class, any()), HiveRelFactories.HIVE_BUILDER, (String) null);
    }

    /**
     * Rewrites the matched aggregate if at least one {@code COUNT} argument can be dropped or at
     * least one aggregate call is a duplicate of an earlier one.
     *
     * @param relOptRuleCall the rule call whose first operand is the aggregate to simplify
     */
    public void onMatch(RelOptRuleCall relOptRuleCall) {
        final Aggregate aggregate = relOptRuleCall.rel(0);
        final List<RelDataTypeField> inputFields = aggregate.getInput().getRowType().getFieldList();
        final List<AggregateCall> originalCalls = aggregate.getAggCallList();
        // Aggregate-call columns start right after the group (and indicator) columns.
        final int groupFieldCount = aggregate.getGroupCount() + aggregate.getIndicatorCount();

        final Map<AggregateCall, Integer> outputIndexByCall = new HashMap<>();
        final List<AggregateCall> reducedCalls = new ArrayList<>(originalCalls.size());
        // For each original call, the column of the reduced aggregate that carries its value.
        final List<Integer> remappedIndexes = new ArrayList<>(originalCalls.size());
        boolean countArgsDropped = false;
        boolean duplicatesDropped = false;

        for (AggregateCall originalCall : originalCalls) {
            final AggregateCall call = dropNonNullableCountArgs(originalCall, inputFields);
            if (call != originalCall) {
                countArgsDropped = true;
            }
            Integer outputIndex = outputIndexByCall.get(call);
            if (outputIndex == null) {
                outputIndex = groupFieldCount + reducedCalls.size();
                reducedCalls.add(call);
                outputIndexByCall.put(call, outputIndex);
            } else {
                duplicatesDropped = true;
            }
            remappedIndexes.add(outputIndex);
        }

        if (!countArgsDropped && !duplicatesDropped) {
            return;
        }

        final Aggregate reduced = aggregate.copy(aggregate.getTraitSet(), aggregate.getInput(),
                aggregate.indicator, aggregate.getGroupSet(), aggregate.getGroupSets(), reducedCalls);
        if (duplicatesDropped) {
            // Fewer aggregate columns than before: a projection has to restore the old layout.
            transformToWithRemapping(relOptRuleCall, aggregate, reduced, groupFieldCount, remappedIndexes);
        } else {
            // Same number and order of columns: the reduced aggregate can replace the old one as is.
            relOptRuleCall.transformTo(reduced);
        }
    }

    /**
     * Removes the arguments of a non-distinct {@code COUNT} that refer to non-nullable input
     * columns.
     *
     * @param call        the aggregate call to inspect
     * @param inputFields the fields of the aggregate's input row type, indexed by argument ordinal
     * @return a copy of {@code call} restricted to its nullable arguments, or {@code call} itself
     *         (same instance) when it is not a non-distinct {@code COUNT} or no argument was removed
     */
    private static AggregateCall dropNonNullableCountArgs(AggregateCall call, List<RelDataTypeField> inputFields) {
        if (call.getAggregation().getKind() != SqlKind.COUNT || call.isDistinct()) {
            return call;
        }
        final List<Integer> args = call.getArgList();
        final List<Integer> nullableArgs = new ArrayList<>(args.size());
        for (Integer arg : args) {
            if (inputFields.get(arg).getType().isNullable()) {
                nullableArgs.add(arg);
            }
        }
        if (nullableArgs.size() == args.size()) {
            return call;
        }
        return call.copy(nullableArgs, call.filterArg);
    }

    /**
     * Registers {@code reduced} topped by a projection that has one expression per column of
     * {@code original}: group columns map to themselves, aggregate columns map to the column of
     * {@code reduced} recorded in {@code remappedIndexes}. Each expression is typed with the type
     * of the corresponding column of {@code original}; only expressions are handed to the
     * builder, no field names.
     *
     * @param relOptRuleCall  the rule call to register the transformation on
     * @param original        the aggregate that was matched
     * @param reduced         the aggregate without duplicate calls
     * @param groupFieldCount number of leading group/indicator columns
     * @param remappedIndexes for each aggregate call of {@code original}, its column in {@code reduced}
     */
    // The expression list stays untyped because its element type is not imported by this file;
    // the suppression is limited to this helper.
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static void transformToWithRemapping(RelOptRuleCall relOptRuleCall, Aggregate original,
            Aggregate reduced, int groupFieldCount, List<Integer> remappedIndexes) {
        final RexBuilder rexBuilder = original.getCluster().getRexBuilder();
        final List<RelDataTypeField> outputFields = original.getRowType().getFieldList();
        final List projection = Lists.newArrayList();
        for (int i = 0; i < outputFields.size(); i++) {
            final int source = i < groupFieldCount ? i : remappedIndexes.get(i - groupFieldCount);
            projection.add(rexBuilder.makeInputRef(outputFields.get(i).getType(), source));
        }
        final RelBuilder builder = relOptRuleCall.builder();
        relOptRuleCall.transformTo(builder.push(reduced).project(projection).build());
    }
}
