package org.apache.hadoop.hive.ql.optimizer.calcite.rules;

import com.google.common.base.Function;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.math.IntMath;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.metadata.RelColumnOrigin;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Util;
import org.apache.hadoop.hive.ql.optimizer.calcite.CalciteSemanticException;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories;
import org.apache.hadoop.hive.ql.optimizer.calcite.RelOptHiveTable;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveGroupingID;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveRelNode;
import org.apache.hadoop.hive.ql.optimizer.calcite.translator.TypeConverter;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveExpandDistinctAggregatesRule.class */
public final class HiveExpandDistinctAggregatesRule extends RelOptRule {
    private static RelFactories.ProjectFactory projFactory;
    RelOptCluster cluster;
    RexBuilder rexBuilder;
    public static final HiveExpandDistinctAggregatesRule INSTANCE = new HiveExpandDistinctAggregatesRule(HiveAggregate.class, HiveRelFactories.HIVE_PROJECT_FACTORY);
    protected static final Logger LOG = LoggerFactory.getLogger(HiveExpandDistinctAggregatesRule.class);

    public HiveExpandDistinctAggregatesRule(Class<? extends Aggregate> cls, RelFactories.ProjectFactory projectFactory) {
        super(operand(cls, any()));
        this.cluster = null;
        this.rexBuilder = null;
        projFactory = projectFactory;
    }

    public void onMatch(RelOptRuleCall relOptRuleCall) {
        Aggregate rel = relOptRuleCall.rel(0);
        int numCountDistinctCall = getNumCountDistinctCall(rel);
        if (numCountDistinctCall == 0) {
            return;
        }
        int i = 0;
        List<List<Integer>> arrayList = new ArrayList<>();
        LinkedHashSet linkedHashSet = new LinkedHashSet();
        HashSet hashSet = new HashSet();
        for (AggregateCall aggregateCall : rel.getAggCallList()) {
            if (aggregateCall.isDistinct()) {
                ArrayList arrayList2 = new ArrayList();
                for (Integer num : aggregateCall.getArgList()) {
                    arrayList2.add(num);
                    hashSet.add(num);
                }
                arrayList.add(arrayList2);
                linkedHashSet.add(arrayList2);
            } else {
                i++;
            }
        }
        Util.permAssert(linkedHashSet.size() > 0, "containsDistinctCall lied");
        if (numCountDistinctCall > 1 && numCountDistinctCall == rel.getAggCallList().size() && rel.getGroupSet().isEmpty()) {
            LOG.debug("Trigger countDistinct rewrite. numCountDistinct is " + numCountDistinctCall);
            this.cluster = rel.getCluster();
            this.rexBuilder = this.cluster.getRexBuilder();
            List<Integer> arrayList3 = new ArrayList<>();
            arrayList3.addAll(hashSet);
            Collections.sort(arrayList3);
            try {
                relOptRuleCall.transformTo(convert(rel, arrayList, arrayList3));
                return;
            } catch (CalciteSemanticException e) {
                LOG.debug(e.toString());
                throw new RuntimeException(e);
            }
        }
        if (i == 0 && linkedHashSet.size() == 1) {
            Iterator it = ((List) linkedHashSet.iterator().next()).iterator();
            while (it.hasNext()) {
                Set<RelColumnOrigin> columnOrigins = RelMetadataQuery.instance().getColumnOrigins(rel, ((Integer) it.next()).intValue());
                if (null != columnOrigins) {
                    for (RelColumnOrigin relColumnOrigin : columnOrigins) {
                        if (relColumnOrigin.getOriginTable().getPartColInfoMap().containsKey(Integer.valueOf(relColumnOrigin.getOriginColumnOrdinal()))) {
                            return;
                        }
                    }
                }
            }
            relOptRuleCall.transformTo(convertMonopole(rel, (List) linkedHashSet.iterator().next()));
        }
    }

    private RelNode convert(Aggregate aggregate, List<List<Integer>> list, List<Integer> list2) throws CalciteSemanticException {
        HashMap hashMap = new HashMap();
        ArrayList arrayList = new ArrayList();
        return createCount(createGroupingSets(aggregate, list, arrayList, hashMap, list2), list, arrayList, hashMap, list2);
    }

    private int getGroupingIdValue(List<Integer> list, List<Integer> list2, int i) {
        int pow = IntMath.pow(2, i) - 1;
        Iterator<Integer> it = list.iterator();
        while (it.hasNext()) {
            pow &= (1 << ((i - list2.indexOf(Integer.valueOf(it.next().intValue()))) - 1)) ^ (-1);
        }
        return pow;
    }

    private RelNode createCount(Aggregate aggregate, List<List<Integer>> list, List<List<Integer>> list2, Map<Integer, Integer> map, List<Integer> list3) throws CalciteSemanticException {
        List transform = Lists.transform(aggregate.getRowType().getFieldList(), new Function<RelDataTypeField, RexNode>() { // from class: org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveExpandDistinctAggregatesRule.1
            @Override // com.google.common.base.Function
            public RexNode apply(RelDataTypeField relDataTypeField) {
                return new RexInputRef(relDataTypeField.getIndex(), relDataTypeField.getType());
            }
        });
        ArrayList newArrayList = Lists.newArrayList();
        for (List<Integer> list4 : list2) {
            RexNode makeCall = this.rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, new RexNode[]{(RexNode) transform.get(transform.size() - 1), this.rexBuilder.makeExactLiteral(new BigDecimal(getGroupingIdValue(list4, list3, aggregate.getGroupCount())))});
            if (list4.size() == 1) {
                makeCall = this.rexBuilder.makeCall(SqlStdOperatorTable.AND, new RexNode[]{makeCall, this.rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, new RexNode[]{(RexNode) transform.get(list4.get(0).intValue())})});
            }
            newArrayList.add(this.rexBuilder.makeCall(SqlStdOperatorTable.CASE, new RexNode[]{makeCall, this.rexBuilder.makeExactLiteral(BigDecimal.ONE), this.rexBuilder.constantNull()}));
        }
        HiveProject create = HiveProject.create(aggregate, newArrayList, null);
        ArrayList newArrayList2 = Lists.newArrayList();
        RelDataType convert = TypeConverter.convert(TypeInfoFactory.longTypeInfo, this.cluster.getTypeFactory());
        for (int i = 0; i < list2.size(); i++) {
            newArrayList2.add(HiveCalciteUtil.createSingleArgAggCall("count", this.cluster, TypeInfoFactory.longTypeInfo, Integer.valueOf(i), convert));
        }
        HiveAggregate hiveAggregate = new HiveAggregate(this.cluster, this.cluster.traitSetOf(HiveRelNode.CONVENTION), create, ImmutableBitSet.of(), null, newArrayList2);
        if (map.isEmpty()) {
            return hiveAggregate;
        }
        List transform2 = Lists.transform(hiveAggregate.getRowType().getFieldList(), new Function<RelDataTypeField, RexNode>() { // from class: org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveExpandDistinctAggregatesRule.2
            @Override // com.google.common.base.Function
            public RexNode apply(RelDataTypeField relDataTypeField) {
                return new RexInputRef(relDataTypeField.getIndex(), relDataTypeField.getType());
            }
        });
        ArrayList newArrayList3 = Lists.newArrayList();
        int i2 = 0;
        for (int i3 = 0; i3 < list.size(); i3++) {
            if (map.containsKey(Integer.valueOf(i3))) {
                newArrayList3.add(transform2.get(map.get(Integer.valueOf(i3)).intValue()));
            } else {
                int i4 = i2;
                i2++;
                newArrayList3.add(transform2.get(i4));
            }
        }
        return HiveProject.create(hiveAggregate, newArrayList3, null);
    }

    private Aggregate createGroupingSets(Aggregate aggregate, List<List<Integer>> list, List<List<Integer>> list2, Map<Integer, Integer> map, List<Integer> list3) {
        ImmutableBitSet of = ImmutableBitSet.of(list3);
        ArrayList arrayList = new ArrayList();
        for (int i = 0; i < list.size(); i++) {
            List<Integer> list4 = list.get(i);
            ImmutableBitSet of2 = ImmutableBitSet.of(list4);
            int indexOf = arrayList.indexOf(of2);
            if (indexOf == -1) {
                arrayList.add(of2);
                list2.add(list4);
            } else {
                map.put(Integer.valueOf(i), Integer.valueOf(indexOf));
            }
        }
        Collections.sort(arrayList, ImmutableBitSet.COMPARATOR);
        ArrayList arrayList2 = new ArrayList();
        arrayList2.add(AggregateCall.create(HiveGroupingID.INSTANCE, false, new ImmutableList.Builder().build(), -1, this.cluster.getTypeFactory().createSqlType(SqlTypeName.INTEGER), HiveGroupingID.INSTANCE.getName()));
        return new HiveAggregate(this.cluster, this.cluster.traitSetOf(HiveRelNode.CONVENTION), aggregate.getInput(), of, arrayList, arrayList2);
    }

    private int getNumCountDistinctCall(Aggregate aggregate) {
        int i = 0;
        for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
            if (aggregateCall.isDistinct() && aggregateCall.getAggregation().getName().equalsIgnoreCase("count")) {
                i++;
            }
        }
        return i;
    }

    private RelNode convertMonopole(Aggregate aggregate, List<Integer> list) {
        HashMap hashMap = new HashMap();
        Aggregate createSelectDistinct = createSelectDistinct(aggregate, list, hashMap);
        ArrayList newArrayList = Lists.newArrayList(aggregate.getAggCallList());
        rewriteAggCalls(newArrayList, list, hashMap);
        int cardinality = aggregate.getGroupSet().cardinality();
        RelTraitSet traitSet = aggregate.getTraitSet();
        aggregate.getClass();
        return aggregate.copy(traitSet, createSelectDistinct, false, ImmutableBitSet.range(cardinality), (List) null, newArrayList);
    }

    private static void rewriteAggCalls(List<AggregateCall> list, List<Integer> list2, Map<Integer, Integer> map) {
        for (int i = 0; i < list.size(); i++) {
            AggregateCall aggregateCall = list.get(i);
            if (aggregateCall.isDistinct() && aggregateCall.getArgList().equals(list2)) {
                int size = aggregateCall.getArgList().size();
                ArrayList arrayList = new ArrayList(size);
                for (int i2 = 0; i2 < size; i2++) {
                    arrayList.add(map.get((Integer) aggregateCall.getArgList().get(i2)));
                }
                list.set(i, new AggregateCall(aggregateCall.getAggregation(), false, arrayList, aggregateCall.getType(), aggregateCall.getName()));
            }
        }
    }

    private static Aggregate createSelectDistinct(Aggregate aggregate, List<Integer> list, Map<Integer, Integer> map) {
        ArrayList arrayList = new ArrayList();
        RelNode input = aggregate.getInput();
        List fieldList = input.getRowType().getFieldList();
        Iterator it = aggregate.getGroupSet().iterator();
        while (it.hasNext()) {
            int intValue = ((Integer) it.next()).intValue();
            map.put(Integer.valueOf(intValue), Integer.valueOf(arrayList.size()));
            arrayList.add(RexInputRef.of2(intValue, fieldList));
        }
        for (Integer num : list) {
            if (map.get(num) == null) {
                map.put(num, Integer.valueOf(arrayList.size()));
                arrayList.add(RexInputRef.of2(num.intValue(), fieldList));
            }
        }
        return aggregate.copy(aggregate.getTraitSet(), projFactory.createProject(input, Pair.left(arrayList), Pair.right(arrayList)), false, ImmutableBitSet.range(arrayList.size()), (List) null, ImmutableList.of());
    }
}
